diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
index e4ba980c..11499a51 100644
--- a/lerobot/common/datasets/transforms.py
+++ b/lerobot/common/datasets/transforms.py
@@ -1,8 +1,9 @@
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Dict, Sequence
 
 import torch
 from torchvision.transforms import v2
 from torchvision.transforms.v2 import Transform
+from torchvision.transforms.v2 import functional as F  # noqa: N812
 
 
 class RandomSubsetApply(Transform):
@@ -45,17 +46,39 @@ class RandomSubsetApply(Transform):
         return "\n".join(format_string)
 
 
-def make_transforms(cfg):
-    image_transforms = []
-    if "colorjitter" in cfg.list:
-        image_transforms.append(
-            v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
-        )
-    if "sharpness" in cfg.list:
-        image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
-    if "blur" in cfg.list:
-        image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
+class RangeRandomSharpness(Transform):
+    """Similar to RandomAdjustSharpness, but with p=1 and a sharpness_factor sampled uniformly
+    from [range_min, range_max] on every call.
+
+    If the input is a :class:`torch.Tensor`, it is expected to have [..., 1 or 3, H, W] shape,
+    where ... means an arbitrary number of leading dimensions.
+    """
+
+    def __init__(self, range_min: float, range_max: float) -> None:
+        super().__init__()
+        self.range_min, self.range_max = self._check_input(range_min, range_max)
+
+    def _check_input(self, range_min, range_max):
+        if range_min < 0:
+            raise ValueError("range_min must be non-negative.")
+        if range_min > range_max:
+            raise ValueError("range_max must be greater than or equal to range_min.")
+        return range_min, range_max
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1)
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
+
+
+def make_transforms(cfg):
+    image_transforms = [
+        v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
+        v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
+        v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
+        v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
+        RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
+    ]
+
+    # WIP
     return v2.Compose(
         [RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
     )
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 04c687bd..3c195c80 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -60,14 +60,26 @@ wandb:
 
 image_transform:
   enable: false
-  colorjitter_factor: 0.5
-  colorjitter_p: 1.O
-  sharpness_factor: 3.0
-  # Should be more than 1, setting parameter to 1 does not change the image
-  sharpness_p: 0.5
-  blur_factor: 0.2
-  # Should be less than 1, setting parameter to 1 does not change the image
-  blur_p: 0.5
-  n_subset: 3 # Maximum number of transforms to apply
-  list: ["colorjitter", "sharpness", "blur"]
+  max_num_transforms: 3
+  random_order: false
+  brightness:
+    weight: 1
+    min: 0.5
+    max: 1.5
+  contrast:
+    weight: 1
+    min: 0.5
+    max: 1.5
+  saturation:
+    weight: 1
+    min: 0.5
+    max: 1.5
+  hue:
+    weight: 1
+    min: -0.1
+    max: 0.1
+  sharpness:
+    weight: 1
+    min: 0.9
+    max: 1.1
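
For reference, here is a minimal usage sketch of the new pipeline (not part of the diff). It assumes an OmegaConf/Hydra-style cfg whose keys mirror the new `image_transform` section, and it keeps `n_subset` because `make_transforms` still reads `cfg.n_subset` in this WIP state; the fake camera frame and its shape are illustrative only.

```python
# Hypothetical usage sketch for the new make_transforms() pipeline.
# Assumptions: an OmegaConf dict mirroring the new `image_transform` config,
# plus `n_subset`, which make_transforms still reads in this WIP revision.
import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.transforms import make_transforms

cfg = OmegaConf.create(
    {
        "n_subset": 3,  # maximum number of transforms applied per sample
        "brightness": {"min": 0.5, "max": 1.5},
        "contrast": {"min": 0.5, "max": 1.5},
        "saturation": {"min": 0.5, "max": 1.5},
        "hue": {"min": -0.1, "max": 0.1},
        "sharpness": {"min": 0.9, "max": 1.1},
    }
)

transform = make_transforms(cfg)

# A fake uint8 camera frame (C, H, W): the composed transform applies a random
# subset of the jitters, then converts to float32 scaled to [0, 1].
frame = torch.randint(0, 256, (3, 96, 128), dtype=torch.uint8)
out = transform(frame)
print(out.dtype, out.shape)  # torch.float32 torch.Size([3, 96, 128])
```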