Updated config to match transforms
This commit is contained in:
commit
66629a956d
|
@ -17,9 +17,9 @@ import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import ListConfig, OmegaConf
|
from omegaconf import ListConfig, OmegaConf
|
||||||
from torchvision.transforms import v2
|
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||||
|
from lerobot.common.datasets.transforms import make_transforms
|
||||||
|
|
||||||
|
|
||||||
def resolve_delta_timestamps(cfg):
|
def resolve_delta_timestamps(cfg):
|
||||||
|
@ -72,23 +72,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||||
|
|
||||||
resolve_delta_timestamps(cfg)
|
resolve_delta_timestamps(cfg)
|
||||||
|
|
||||||
if cfg.image_transform.enable:
|
transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
|
||||||
transform = v2.Compose(
|
|
||||||
[
|
|
||||||
v2.ColorJitter(
|
|
||||||
brightness=cfg.image_transform.colorjitter_factor,
|
|
||||||
contrast=cfg.image_transform.colorjitter_factor,
|
|
||||||
),
|
|
||||||
v2.RandomAdjustSharpness(
|
|
||||||
cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p
|
|
||||||
),
|
|
||||||
# Using RandomAdjustSharpness with parameter < 1 adds blur to the image
|
|
||||||
v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
transform = None
|
|
||||||
|
|
||||||
if isinstance(cfg.dataset_repo_id, str):
|
if isinstance(cfg.dataset_repo_id, str):
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
from typing import Any, Callable, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from torchvision.transforms.v2 import Transform
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSubsetApply(Transform):
    """Apply a random subset of N transformations from a list, in a random order.

    Args:
        transforms: sequence of candidate transformations (callables).
        n_subset: number of transformations to apply per call; must satisfy
            ``0 <= n_subset <= len(transforms)``.

    Raises:
        TypeError: if ``transforms`` is not a sequence of callables.
        ValueError: if ``n_subset`` is outside ``[0, len(transforms)]``.
    """

    def __init__(self, transforms: Sequence[Callable], n_subset: int) -> None:
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if not (0 <= n_subset <= len(transforms)):
            raise ValueError(f"N should be in the interval [0, {len(transforms)}]")

        self.transforms = transforms
        self.N = n_subset

    def forward(self, *inputs: Any) -> Any:
        needs_unpacking = len(inputs) > 1
        # Bug fix: default to the untouched input(s) so that N == 0 (allowed by
        # the constructor) is a no-op instead of raising NameError on `outputs`.
        outputs = inputs if needs_unpacking else inputs[0]

        # torch.randperm both picks the N transforms and randomizes the order
        # in which they are applied.
        selected_transforms = torch.randperm(len(self.transforms))[: self.N]

        for idx in selected_transforms:
            transform = self.transforms[idx]
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)

        return outputs

    def extra_repr(self) -> str:
        # One line per configured transform, prefixed by the subset size.
        format_string = [f"N={self.N}"]
        for t in self.transforms:
            format_string.append(f" {t}")
        return "\n".join(format_string)
|
||||||
|
|
||||||
|
|
||||||
|
def make_transforms(cfg):
    """Build the image augmentation pipeline from the ``image_transform`` config.

    Each transform whose key appears in ``cfg.list_transforms`` is collected,
    wrapped in :class:`RandomSubsetApply` so that at most ``cfg.n_subset`` of
    them are applied per call (in random order), then the image is converted
    to float32 scaled to [0, 1].

    Args:
        cfg: the ``image_transform`` config node; must provide
            ``list_transforms``, ``n_subset`` and the per-transform parameters.

    Returns:
        A ``torchvision.transforms.v2.Compose`` pipeline.
    """
    image_transforms = []
    # NOTE(review): `in` on a list is an exact-string membership test, so the
    # config list must contain "jit" (not "colorjittor") for ColorJitter to be
    # enabled — confirm the intended key against the yaml config.
    if 'jit' in cfg.list_transforms:
        image_transforms.append(
            v2.ColorJitter(brightness=cfg.colorjitter_range, contrast=cfg.colorjitter_range)
        )
    if 'sharpness' in cfg.list_transforms:
        image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
    if 'blur' in cfg.list_transforms:
        # RandomAdjustSharpness with a factor < 1 blurs the image.
        image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))

    # Bug fix: v2.Compose takes a single sequence of transforms, not varargs;
    # the original two-positional-argument call raises TypeError.
    return v2.Compose(
        [
            RandomSubsetApply(image_transforms, n_subset=cfg.n_subset),
            v2.ToDtype(torch.float32, scale=True),
        ]
    )
|
|
@ -60,13 +60,14 @@ wandb:
|
||||||
|
|
||||||
image_transform:
|
image_transform:
|
||||||
enable: false
|
enable: false
|
||||||
colorjittor_factor: 0.5
|
colorjitter_range: [0, 1]
|
||||||
# Noise sampled randomly from 0 to colorjittor_factor
|
# Range from which to sample the colorjitter factor
|
||||||
sharpness_factor: 2
|
sharpness_factor: 3
|
||||||
# Should be more than 1, setting parameter to 1 does not change the image
|
# Should be more than 1, setting parameter to 1 does not change the image
|
||||||
sharpness_p: 0.5
|
sharpness_p: 0.5
|
||||||
# Probability that Sharpness is applied
|
blur_factor: 0.2
|
||||||
blur_factor: 0.5
|
|
||||||
# Should be less than 1, setting parameter to 1 does not change the image
|
# Should be less than 1, setting parameter to 1 does not change the image
|
||||||
blur_p: 0.5
|
blur_p: 0.5
|
||||||
# Probability that Blur is applied
|
n_subset: 3
|
||||||
|
# Maximum number of transforms to apply
|
||||||
|
list_transforms: ["jit", "sharpness", "blur"]
|
||||||
|
|
Loading…
Reference in New Issue