diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index bae0677e..ee6c5715 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -19,7 +19,7 @@ import torch from omegaconf import ListConfig, OmegaConf from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset -from lerobot.common.datasets.transforms import make_image_transforms +from lerobot.common.datasets.transforms import get_image_transforms def resolve_delta_timestamps(cfg): @@ -72,7 +72,22 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData resolve_delta_timestamps(cfg) - image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None + image_transforms = None + if cfg.image_transforms.enable: + image_transforms = get_image_transforms( + brightness_weight=cfg.brightness.weight, + brightness_min_max=cfg.brightness.min_max, + contrast_weight=cfg.contrast.weight, + contrast_min_max=cfg.contrast.min_max, + saturation_weight=cfg.saturation.weight, + saturation_min_max=cfg.saturation.min_max, + hue_weight=cfg.hue.weight, + hue_min_max=cfg.hue.min_max, + sharpness_weight=cfg.sharpness.weight, + sharpness_min_max=cfg.sharpness.min_max, + max_num_transforms=cfg.max_num_transforms, + random_order=cfg.random_order, + ) if isinstance(cfg.dataset_repo_id, str): dataset = LeRobotDataset( diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 78282a1f..9cc40768 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -98,26 +98,60 @@ class RangeRandomSharpness(Transform): return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) -def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32): - transforms_list = [ - v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)), - v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)), - v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)), - v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)), - RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max), - ] - transforms_weights = [ - cfg.brightness.weight, - cfg.contrast.weight, - cfg.saturation.weight, - cfg.hue.weight, - cfg.sharpness.weight, - ] +def get_image_transforms( + brightness_weight: float = 1.0, + brightness_min_max: tuple[float, float] | None = None, + contrast_weight: float = 1.0, + contrast_min_max: tuple[float, float] | None = None, + saturation_weight: float = 1.0, + saturation_min_max: tuple[float, float] | None = None, + hue_weight: float = 1.0, + hue_min_max: tuple[float, float] | None = None, + sharpness_weight: float = 1.0, + sharpness_min_max: tuple[float, float] | None = None, + max_num_transforms: int | None = None, + random_order: bool = False, + ): + + def check_value_error(name, weight, min_max): + if min_max is not None: + if len(min_max) != 2: + raise ValueError(f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided.") + if weight < 0.: + raise ValueError(f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight}).") - transforms = RandomSubsetApply( - transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order + check_value_error("brightness", brightness_weight, brightness_min_max) + check_value_error("contrast", contrast_weight, contrast_min_max) + check_value_error("saturation", saturation_weight, saturation_min_max) + check_value_error("hue", hue_weight, hue_min_max) + check_value_error("sharpness", sharpness_weight, sharpness_min_max) + + weights = [] + transforms = [] + if brightness_min_max is not None: + weights.append(brightness_weight) + transforms.append(v2.ColorJitter(brightness=brightness_min_max)) + if contrast_min_max is not None: + weights.append(contrast_weight) + transforms.append(v2.ColorJitter(contrast=contrast_min_max)) + if saturation_min_max is not None: + weights.append(saturation_weight) + transforms.append(v2.ColorJitter(saturation=saturation_min_max)) + if hue_min_max is not None: + weights.append(hue_weight) + transforms.append(v2.ColorJitter(hue=hue_min_max)) + if sharpness_min_max is not None: + weights.append(sharpness_weight) + transforms.append(RangeRandomSharpness(**sharpness_min_max)) + + if max_num_transforms is None: + n_subset = len(transforms) + else: + n_subset = min(len(transforms), max_num_transforms) + + final_transforms = RandomSubsetApply( + transforms, p=weights, n_subset=n_subset, random_order=random_order ) - # return transforms - # return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)]) - return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=False)]) + # TODO(rcadene, aliberts): add v2.ToDtype float16? + return final_transforms diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index a5c58e1f..096380d8 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -67,7 +67,7 @@ image_transforms: # weight: This represents the multinomial probability (with no replacement) # used for sampling the transform. If the sum of the weights is not 1, # they will be normalized. - # min/max: Lower & upper bound respectively used for sampling the transform's parameter + # min_max: Lower & upper bound respectively used for sampling the transform's parameter # (following uniform distribution) when it's applied. enable: false # This is the number of transforms (sampled from these below) that will be applied to each frame. @@ -78,21 +78,16 @@ image_transforms: random_order: false brightness: weight: 1 - min: 0.8 - max: 1.2 + min_max: [0.8, 1.2] contrast: weight: 1 - min: 0.8 - max: 1.2 + min_max: [0.8, 1.2] saturation: weight: 1 - min: 0.5 - max: 1.5 + min_max: [0.5, 1.5] hue: weight: 1 - min: -0.05 - max: 0.05 + min_max: [-0.05, 0.05] sharpness: weight: 1 - min: 0.8 - max: 1.2 + min_max: [0.8, 1.2] diff --git a/tests/test_transforms.py b/tests/test_transforms.py index def5ff2a..32d89ae4 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,5 +1,6 @@ from pathlib import Path +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset import numpy as np import pytest import torch @@ -8,144 +9,245 @@ from PIL import Image from torchvision.transforms import v2 from torchvision.transforms.v2 import functional as F # noqa: N812 -from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms +from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, get_image_transforms from lerobot.common.utils.utils import seeded_context -class TestRandomSubsetApply: - @pytest.fixture(autouse=True) - def setup(self): - self.jitters = [ - v2.ColorJitter(brightness=0.5), - v2.ColorJitter(contrast=0.5), - v2.ColorJitter(saturation=0.5), - ] - self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] - self.img = torch.rand(3, 224, 224) +# test_make_image_transforms +# - - @pytest.mark.parametrize("p", [[0, 1], [1, 0]]) - def test_random_choice(self, p): - random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False) - output = random_choice(self.img) +# test backward compatibility torchvision +# - save artifacts - p_horz, _ = p - if p_horz: - torch.testing.assert_close(output, F.horizontal_flip(self.img)) - else: - torch.testing.assert_close(output, F.vertical_flip(self.img)) - - def test_transform_all(self): - transform = RandomSubsetApply(self.jitters) - output = transform(self.img) - assert output.shape == self.img.shape - - def test_transform_subset(self): - transform = RandomSubsetApply(self.jitters, n_subset=2) - output = transform(self.img) - assert output.shape == self.img.shape - - def test_random_order(self): - random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True) - # We can't really check whether the transforms are actually applied in random order. However, - # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform - # applies them in random order, we can use a fixed order to compute the expected value. - actual = random_order(self.img) - expected = v2.Compose(self.flips)(self.img) - torch.testing.assert_close(actual, expected) - - def test_probability_length_mismatch(self): - with pytest.raises(ValueError): - RandomSubsetApply(self.jitters, p=[0.5, 0.5]) - - def test_invalid_n_subset(self): - with pytest.raises(ValueError): - RandomSubsetApply(self.jitters, n_subset=5) +# test backward compatibility default yaml (enable false, enable true) +# - save artifacts -class TestRangeRandomSharpness: - @pytest.fixture(autouse=True) - def setup(self): - self.img = torch.rand(3, 224, 224) - - def test_valid_range(self): - transform = RangeRandomSharpness(0.1, 2.0) - output = transform(self.img) - assert output.shape == self.img.shape - - def test_invalid_range_min_negative(self): - with pytest.raises(ValueError): - RangeRandomSharpness(-0.1, 2.0) - - def test_invalid_range_max_smaller(self): - with pytest.raises(ValueError): - RangeRandomSharpness(2.0, 0.1) +def test_get_image_transforms_no_transform(): + get_image_transforms() + get_image_transforms(sharpness_weight=0.0) + get_image_transforms(max_num_transforms=0) -class TestMakeImageTransforms: - @pytest.fixture(autouse=True) - def setup(self): - """Seed should be the same as the one that was used to generate artifacts""" - self.config = { - "enable": True, - "max_num_transforms": 1, - "random_order": False, - "brightness": {"weight": 0, "min": 2.0, "max": 2.0}, - "contrast": { - "weight": 0, - "min": 2.0, - "max": 2.0, - }, - "saturation": { - "weight": 0, - "min": 2.0, - "max": 2.0, - }, - "hue": { - "weight": 0, - "min": 0.5, - "max": 0.5, - }, - "sharpness": { - "weight": 0, - "min": 2.0, - "max": 2.0, - }, - } - self.path = Path("tests/data/save_image_transforms") - self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png") - self.transforms = { - "brightness": v2.ColorJitter(brightness=(2.0, 2.0)), - "contrast": v2.ColorJitter(contrast=(2.0, 2.0)), - "saturation": v2.ColorJitter(saturation=(2.0, 2.0)), - "hue": v2.ColorJitter(hue=(0.5, 0.5)), - "sharpness": RangeRandomSharpness(2.0, 2.0), - } +@pytest.fixture +def img(): + # dataset = LeRobotDataset("lerobot/pusht") + # item = dataset[0] + # return item["observation.image"] + path = "tests/data/save_image_transforms/original_frame.png" + img_chw = torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1) + return img_chw - @staticmethod - def load_png_to_tensor(path: Path): - return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1) +def test_get_image_transforms_brightness(img): + brightness_min_max = (0.5, 0.5) + tf_actual = get_image_transforms(brightness_weight=1., brightness_min_max=brightness_min_max) + tf_expected = v2.ColorJitter(brightness=brightness_min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) - @pytest.mark.parametrize( - "transform_key, seed", - [ - ("brightness", 1336), - ("contrast", 1336), - ("saturation", 1336), - ("hue", 1336), - ("sharpness", 1336), - ], +def test_get_image_transforms_contrast(img): + contrast_min_max = (0.5, 0.5) + tf_actual = get_image_transforms(contrast_weight=1., contrast_min_max=contrast_min_max) + tf_expected = v2.ColorJitter(contrast=contrast_min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + +def test_get_image_transforms_saturation(img): + saturation_min_max = (0.5, 0.5) + tf_actual = get_image_transforms(saturation_weight=1., saturation_min_max=saturation_min_max) + tf_expected = v2.ColorJitter(saturation=saturation_min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + +def test_get_image_transforms_hue(img): + hue_min_max = (0.5, 0.5) + tf_actual = get_image_transforms(hue_weight=1., hue_min_max=hue_min_max) + tf_expected = v2.ColorJitter(hue=hue_min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + +def test_get_image_transforms_sharpness(img): + sharpness_min_max = (0.5, 0.5) + tf_actual = get_image_transforms(sharpness_weight=1., sharpness_min_max=sharpness_min_max) + tf_expected = RangeRandomSharpness(**sharpness_min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + +def test_get_image_transforms_max_num_transforms(img): + tf_actual = get_image_transforms( + saturation_min_max=(0.5, 0.5), + constrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=False, ) - def test_single_transform(self, transform_key, seed): - config = self.config - config[transform_key]["weight"] = 1 - cfg = OmegaConf.create(config) + tf_expected = v2.Compose([ + v2.ColorJitter(brightness=(0.5, 0.5)), + v2.ColorJitter(contrast=(0.5, 0.5)), + v2.ColorJitter(saturation=(0.5, 0.5)), + v2.ColorJitter(hue=(0.5, 0.5)), + RangeRandomSharpness(sharpness=(0.5, 0.5)), + ]) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) - actual_t = make_image_transforms(cfg, to_dtype=torch.uint8) - with seeded_context(1336): - actual = actual_t(self.original_frame) - expected_t = self.transforms[transform_key] - with seeded_context(1336): - expected = expected_t(self.original_frame) +def test_get_image_transforms_random_order(img): + out_imgs = [] + with seeded_context(1337): + for _ in range(20): + tf = get_image_transforms( + saturation_min_max=(0.5, 0.5), + constrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=False, + ) + out_imgs.append(tf(img)) + + for i in range(1,10): + with pytest.raises(ValueError): + torch.testing.assert_close(out_imgs[0], out_imgs[i]) - torch.testing.assert_close(actual, expected) + + +def test_backward_compatibility_torchvision(): + pass + +def test_backward_compatibility_default_yaml(): + pass + + +# class TestRandomSubsetApply: +# @pytest.fixture(autouse=True) +# def setup(self): +# self.jitters = [ +# v2.ColorJitter(brightness=0.5), +# v2.ColorJitter(contrast=0.5), +# v2.ColorJitter(saturation=0.5), +# ] +# self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] +# self.img = torch.rand(3, 224, 224) + +# @pytest.mark.parametrize("p", [[0, 1], [1, 0]]) +# def test_random_choice(self, p): +# random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False) +# output = random_choice(self.img) + +# p_horz, _ = p +# if p_horz: +# torch.testing.assert_close(output, F.horizontal_flip(self.img)) +# else: +# torch.testing.assert_close(output, F.vertical_flip(self.img)) + +# def test_transform_all(self): +# transform = RandomSubsetApply(self.jitters) +# output = transform(self.img) +# assert output.shape == self.img.shape + +# def test_transform_subset(self): +# transform = RandomSubsetApply(self.jitters, n_subset=2) +# output = transform(self.img) +# assert output.shape == self.img.shape + +# def test_random_order(self): +# random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True) +# # We can't really check whether the transforms are actually applied in random order. However, +# # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform +# # applies them in random order, we can use a fixed order to compute the expected value. +# actual = random_order(self.img) +# expected = v2.Compose(self.flips)(self.img) +# torch.testing.assert_close(actual, expected) + +# def test_probability_length_mismatch(self): +# with pytest.raises(ValueError): +# RandomSubsetApply(self.jitters, p=[0.5, 0.5]) + +# def test_invalid_n_subset(self): +# with pytest.raises(ValueError): +# RandomSubsetApply(self.jitters, n_subset=5) + + +# class TestRangeRandomSharpness: +# @pytest.fixture(autouse=True) +# def setup(self): +# self.img = torch.rand(3, 224, 224) + +# def test_valid_range(self): +# transform = RangeRandomSharpness(0.1, 2.0) +# output = transform(self.img) +# assert output.shape == self.img.shape + +# def test_invalid_range_min_negative(self): +# with pytest.raises(ValueError): +# RangeRandomSharpness(-0.1, 2.0) + +# def test_invalid_range_max_smaller(self): +# with pytest.raises(ValueError): +# RangeRandomSharpness(2.0, 0.1) + + +# class TestMakeImageTransforms: +# @pytest.fixture(autouse=True) +# def setup(self): +# """Seed should be the same as the one that was used to generate artifacts""" +# self.config = { +# "enable": True, +# "max_num_transforms": 1, +# "random_order": False, +# "brightness": {"weight": 0, "min": 2.0, "max": 2.0}, +# "contrast": { +# "weight": 0, +# "min": 2.0, +# "max": 2.0, +# }, +# "saturation": { +# "weight": 0, +# "min": 2.0, +# "max": 2.0, +# }, +# "hue": { +# "weight": 0, +# "min": 0.5, +# "max": 0.5, +# }, +# "sharpness": { +# "weight": 0, +# "min": 2.0, +# "max": 2.0, +# }, +# } +# self.path = Path("tests/data/save_image_transforms") +# self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png") +# self.transforms = { +# "brightness": v2.ColorJitter(brightness=(2.0, 2.0)), +# "contrast": v2.ColorJitter(contrast=(2.0, 2.0)), +# "saturation": v2.ColorJitter(saturation=(2.0, 2.0)), +# "hue": v2.ColorJitter(hue=(0.5, 0.5)), +# "sharpness": RangeRandomSharpness(2.0, 2.0), +# } + +# @staticmethod +# def load_png_to_tensor(path: Path): +# return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1) + +# @pytest.mark.parametrize( +# "transform_key, seed", +# [ +# ("brightness", 1336), +# ("contrast", 1336), +# ("saturation", 1336), +# ("hue", 1336), +# ("sharpness", 1336), +# ], +# ) +# def test_single_transform(self, transform_key, seed): +# config = self.config +# config[transform_key]["weight"] = 1 +# cfg = OmegaConf.create(config) + +# actual_t = make_image_transforms(cfg, to_dtype=torch.uint8) +# with seeded_context(1336): +# actual = actual_t(self.original_frame) + +# expected_t = self.transforms[transform_key] +# with seeded_context(1336): +# expected = expected_t(self.original_frame) + +# torch.testing.assert_close(actual, expected)