diff --git a/tests/data/save_image_transforms/1336_brightness.png b/tests/data/save_image_transforms/1336_brightness.png deleted file mode 100644 index d6503e16..00000000 Binary files a/tests/data/save_image_transforms/1336_brightness.png and /dev/null differ diff --git a/tests/data/save_image_transforms/1336_contrast.png b/tests/data/save_image_transforms/1336_contrast.png deleted file mode 100644 index 759a4cae..00000000 Binary files a/tests/data/save_image_transforms/1336_contrast.png and /dev/null differ diff --git a/tests/data/save_image_transforms/1336_hue.png b/tests/data/save_image_transforms/1336_hue.png deleted file mode 100644 index 45420663..00000000 Binary files a/tests/data/save_image_transforms/1336_hue.png and /dev/null differ diff --git a/tests/data/save_image_transforms/1336_saturation.png b/tests/data/save_image_transforms/1336_saturation.png deleted file mode 100644 index eb3b7d8a..00000000 Binary files a/tests/data/save_image_transforms/1336_saturation.png and /dev/null differ diff --git a/tests/data/save_image_transforms/1336_sharpness.png b/tests/data/save_image_transforms/1336_sharpness.png deleted file mode 100644 index af11e14e..00000000 Binary files a/tests/data/save_image_transforms/1336_sharpness.png and /dev/null differ diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 16f04f9a..def5ff2a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,17 +1,15 @@ from pathlib import Path + import numpy as np -from omegaconf import OmegaConf import pytest import torch +from omegaconf import OmegaConf +from PIL import Image from torchvision.transforms import v2 from torchvision.transforms.v2 import functional as F # noqa: N812 -from PIL import Image -from safetensors.torch import load_file from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms -from lerobot.common.datasets.utils import flatten_dict -from lerobot.common.utils.utils import init_hydra_config, seeded_context -from tests.utils import DEFAULT_CONFIG_PATH +from lerobot.common.utils.utils import seeded_context class TestRandomSubsetApply: @@ -83,7 +81,7 @@ class TestRangeRandomSharpness: RangeRandomSharpness(2.0, 0.1) -class TestMakeTransforms: +class TestMakeImageTransforms: @pytest.fixture(autouse=True) def setup(self): """Seed should be the same as the one that was used to generate artifacts""" @@ -91,47 +89,41 @@ class TestMakeTransforms: "enable": True, "max_num_transforms": 1, "random_order": False, - "brightness": { - "weight": 0, - "min": 0.0, - "max": 2.0 - }, + "brightness": {"weight": 0, "min": 2.0, "max": 2.0}, "contrast": { "weight": 0, - "min": 0.0, + "min": 2.0, "max": 2.0, }, "saturation": { "weight": 0, - "min": 0.0, + "min": 2.0, "max": 2.0, }, "hue": { "weight": 0, - "min": -0.5, + "min": 0.5, "max": 0.5, }, "sharpness": { "weight": 0, - "min": 0.0, + "min": 2.0, "max": 2.0, }, } self.path = Path("tests/data/save_image_transforms") - # self.expected_frames = load_file(self.path / f"transformed_frames_1336.safetensors") self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png") - # self.original_frame = self.expected_frames["original_frame"] self.transforms = { - "brightness": v2.ColorJitter(brightness=(0.0, 2.0)), - "contrast": v2.ColorJitter(contrast=(0.0, 2.0)), - "saturation": v2.ColorJitter(saturation=(0.0, 2.0)), - "hue": v2.ColorJitter(hue=(-0.5, 0.5)), - "sharpness": RangeRandomSharpness(0.0, 2.0), + "brightness": v2.ColorJitter(brightness=(2.0, 2.0)), + "contrast": v2.ColorJitter(contrast=(2.0, 2.0)), + "saturation": v2.ColorJitter(saturation=(2.0, 2.0)), + "hue": v2.ColorJitter(hue=(0.5, 0.5)), + "sharpness": RangeRandomSharpness(2.0, 2.0), } @staticmethod def load_png_to_tensor(path: Path): - return torch.from_numpy(np.array(Image.open(path).convert('RGB'))).permute(2, 0, 1) + return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1) @pytest.mark.parametrize( "transform_key, seed", @@ -141,27 +133,19 @@ class TestMakeTransforms: ("saturation", 1336), ("hue", 1336), ("sharpness", 1336), - ] + ], ) def test_single_transform(self, transform_key, seed): config = self.config config[transform_key]["weight"] = 1 cfg = OmegaConf.create(config) - transform = make_image_transforms(cfg, to_dtype=torch.uint8) - # expected_t = self.transforms[transform_key] - with seeded_context(seed): - actual = transform(self.original_frame) - # torch.manual_seed(42) - # actual = actual_t(self.original_frame) - # torch.manual_seed(42) - # expected = expected_t(self.original_frame) + actual_t = make_image_transforms(cfg, to_dtype=torch.uint8) + with seeded_context(1336): + actual = actual_t(self.original_frame) - # with seeded_context(1336): - # expected = expected_t(self.original_frame) + expected_t = self.transforms[transform_key] + with seeded_context(1336): + expected = expected_t(self.original_frame) - expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png") - # # expected = self.expected_frames[transform_key] - to_pil = v2.ToPILImage() - to_pil(actual).save(self.path / f"{seed}_{transform_key}_test.png", quality=100) torch.testing.assert_close(actual, expected)