From 6509c3f6d4081988b9e9bec0cedc7fcfc12e6ad7 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Wed, 5 Jun 2024 12:14:57 +0000 Subject: [PATCH] Implement RandomSubsetApply features --- lerobot/common/datasets/transforms.py | 68 ++++++++++++++++++++------- lerobot/configs/default.yaml | 2 +- 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 11499a51..cbb7d5b3 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -8,32 +8,58 @@ from torchvision.transforms.v2 import functional as F # noqa: N812 class RandomSubsetApply(Transform): """ - Apply a random subset of N transformations from a list of transformations in a random order. + Apply a random subset of N transformations from a list of transformations. Args: transforms (sequence or torch.nn.Module): list of transformations - N (int): number of transformations to apply + p (list of floats or None, optional): probability of each transform being picked. + If ``p`` doesn't sum to 1, it is automatically normalized. If ``None`` + (default), all transforms have the same probability. + n_subset (int or None): number of transformations to apply. If ``None``, + all transforms are applied. + random_order (bool): apply transformations in a random order """ - def __init__(self, transforms: Sequence[Callable], n_subset: int) -> None: + def __init__( + self, + transforms: Sequence[Callable], + p: list[float] | None = None, + n_subset: int | None = None, + random_order: bool = False, + ) -> None: super().__init__() if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") - if not (0 <= n_subset <= len(transforms)): - raise ValueError(f"N should be in the interval [0, {len(transforms)}]") + if p is None: + p = [1] * len(transforms) + elif len(p) != len(transforms): + raise ValueError( + f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}" + ) + + if n_subset is None: + n_subset = len(transforms) + elif not isinstance(n_subset, int): + raise TypeError("n_subset should be an int or None") + elif not (0 <= n_subset <= len(transforms)): + raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]") self.transforms = transforms - self.N = n_subset + self.n_subset = n_subset + self.random_order = random_order def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - # Randomly pick N transforms - selected_transforms = torch.randperm(len(self.transforms))[: self.N] + indices = torch.arange(len(self.transforms)) + selected_indices = torch.randperm(len(indices))[: self.n_subset] + if not self.random_order: + selected_indices = selected_indices.sort().values - # Apply selected transforms in random order - for idx in selected_transforms: - transform = self.transforms[idx] + selected_transforms = [self.transforms[i] for i in selected_indices] + print(selected_transforms) + + for transform in selected_transforms: outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,) @@ -66,19 +92,29 @@ class RangeRandomSharpness(Transform): return range_min, range_max def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1) + sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1).item() + print(f"{sharpness_factor=}") return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) def make_transforms(cfg): - image_transforms = [ + transforms_list = [ v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)), v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)), v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)), v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)), RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max), ] - # WIP - return v2.Compose( - [RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)] + transforms_weights = [ + cfg.brightness.weight, + cfg.contrast.weight, + cfg.saturation.weight, + cfg.hue.weight, + cfg.sharpness.weight, + ] + + transforms = RandomSubsetApply( + transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order ) + + return v2.Compose([transforms, v2.ToDtype(torch.float32, scale=True)]) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 3c195c80..b0eb6009 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -59,7 +59,7 @@ wandb: notes: "" image_transform: - enable: false + enable: true # Maximum number of transforms to apply max_num_transforms: 3 random_order: false