diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index d9ce7a0f..370d1640 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -17,9 +17,9 @@ import logging
 
 import torch
 from omegaconf import ListConfig, OmegaConf
-from torchvision.transforms import v2
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
+from lerobot.common.datasets.transforms import make_transforms
 
 
 def resolve_delta_timestamps(cfg):
@@ -72,22 +72,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
 
     resolve_delta_timestamps(cfg)
 
-    if cfg.image_transform.enable:
-        transform = v2.Compose(
-            [
-                v2.ColorJitter(
-                    brightness=cfg.image_transform.colorjitter_factor,
-                    contrast=cfg.image_transform.colorjitter_factor,
-                ),
-                v2.RandomAdjustSharpness(
-                    cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p
-                ),
-                v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
-                v2.ToDtype(torch.float32, scale=True),
-            ]
-        )
-    else:
-        transform = None
+    transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
 
     if isinstance(cfg.dataset_repo_id, str):
         dataset = LeRobotDataset(
diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
new file mode 100644
index 00000000..e5bec5e6
--- /dev/null
+++ b/lerobot/common/datasets/transforms.py
@@ -0,0 +1,62 @@
+from typing import Any, Callable, Sequence
+
+import torch
+from torchvision.transforms import v2
+from torchvision.transforms.v2 import Transform
+
+
+class RandomSubsetApply(Transform):
+    """Apply a random subset of `n_subset` transformations from a list, in a random order.
+
+    Args:
+        transforms: sequence of transformations to sample from.
+        n_subset: number of transformations to apply (0 <= n_subset <= len(transforms)).
+    """
+
+    def __init__(self, transforms: Sequence[Callable], n_subset: int) -> None:
+        super().__init__()
+        if not isinstance(transforms, Sequence):
+            raise TypeError("Argument transforms should be a sequence of callables")
+        if not (0 <= n_subset <= len(transforms)):
+            raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
+
+        self.transforms = transforms
+        self.n_subset = n_subset
+
+    def forward(self, *inputs: Any) -> Any:
+        needs_unpacking = len(inputs) > 1
+
+        # Randomly pick n_subset transforms; iterating the permuted indices also
+        # applies them in a random order.
+        selected_indices = torch.randperm(len(self.transforms))[: self.n_subset]
+        for idx in selected_indices:
+            transform = self.transforms[idx]
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+
+        # Return from `inputs` (always bound) so that n_subset == 0 returns the
+        # inputs unchanged instead of raising NameError on an unbound `outputs`.
+        return inputs if needs_unpacking else inputs[0]
+
+    def extra_repr(self) -> str:
+        format_string = [f"n_subset={self.n_subset}"]
+        for t in self.transforms:
+            format_string.append(f" {t}")
+        return "\n".join(format_string)
+
+
+def make_transforms(cfg):
+    """Build the image augmentation pipeline from an `image_transform` config node."""
+    image_transforms = [
+        v2.ColorJitter(
+            brightness=cfg.colorjitter_factor,
+            contrast=cfg.colorjitter_factor,
+        ),
+        v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p),
+        v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p),
+    ]
+    return v2.Compose(
+        [
+            RandomSubsetApply(image_transforms, n_subset=cfg.n_subset),
+            # Always convert/scale to float32, regardless of which (if any)
+            # augmentations were sampled above, so output dtype is consistent.
+            v2.ToDtype(torch.float32, scale=True),
+        ]
+    )