diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
index 25c68227..eeab633d 100644
--- a/lerobot/common/datasets/transforms.py
+++ b/lerobot/common/datasets/transforms.py
@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
         return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
 
 
-def make_transforms(cfg):
+def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
     transforms_list = [
         v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
         v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
@@ -118,4 +118,6 @@ def make_transforms(cfg):
         transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
     )
 
-    return v2.Compose([transforms, v2.ToDtype(torch.float32, scale=True)])
+    # ToDtype is a no-op when the input already has dtype `to_dtype`, so uint8
+    # frames pass through unchanged while float32 is still scaled to [0, 1].
+    return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)])
diff --git a/tests/scripts/save_image_transforms.py b/tests/scripts/save_image_transforms.py
index 8cc7d7ba..777dcd96 100644
--- a/tests/scripts/save_image_transforms.py
+++ b/tests/scripts/save_image_transforms.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from torchvision.transforms import ToPILImage, v2
+from torchvision.transforms import v2
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.transforms import RangeRandomSharpness
@@ -9,7 +9,7 @@ from lerobot.common.utils.utils import seeded_context
 DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
 ARTIFACT_DIR = "tests/data/save_image_transforms"
 SEED = 1336
-to_pil = ToPILImage()
+to_pil = v2.ToPILImage()
 
 
 def main(repo_id):
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 6dc91f91..8429997a 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -1,9 +1,15 @@
+from pathlib import Path
+
+import numpy as np
 import pytest
 import torch
+from omegaconf import OmegaConf
+from PIL import Image
 from torchvision.transforms import v2
 from torchvision.transforms.v2 import functional as F  # noqa: N812
 
-from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness
+from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms
+from lerobot.common.utils.utils import seeded_context
 
 
 class TestRandomSubsetApply:
@@ -76,5 +82,65 @@ class TestRangeRandomSharpness:
 
 
 class TestMakeTransforms:
-    ...
-    # TODO
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        """Seed should be the same as the one used to generate the artifacts."""
+        self.config = {
+            "enable": True,
+            "max_num_transforms": 1,
+            "random_order": False,
+            "brightness": {
+                "weight": 0,
+                "min": 0.0,
+                "max": 2.0,
+            },
+            "contrast": {
+                "weight": 0,
+                "min": 0.0,
+                "max": 2.0,
+            },
+            "saturation": {
+                "weight": 0,
+                "min": 0.0,
+                "max": 2.0,
+            },
+            "hue": {
+                "weight": 0,
+                "min": -0.5,
+                "max": 0.5,
+            },
+            "sharpness": {
+                "weight": 0,
+                "min": 0.0,
+                "max": 2.0,
+            },
+        }
+        self.path = Path("tests/data/save_image_transforms")
+        self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
+
+    @staticmethod
+    def load_png_to_tensor(path: Path):
+        return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
+
+    @pytest.mark.parametrize(
+        "transform_key, seed",
+        [
+            ("brightness", 1336),
+            ("contrast", 1336),
+            ("saturation", 1336),
+            ("hue", 1336),
+            ("sharpness", 1336),
+        ],
+    )
+    def test_single_transform(self, transform_key, seed):
+        config = self.config
+        config[transform_key]["weight"] = 1
+        cfg = OmegaConf.create(config)
+        # Request uint8 so the output can be compared directly against the PNG artifact.
+        transform = make_transforms(cfg, to_dtype=torch.uint8)
+
+        with seeded_context(seed):
+            actual = transform(self.original_frame)
+
+        expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png")
+        torch.testing.assert_close(actual, expected)
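
Usage note (not part of the patch): `make_transforms` keeps its previous float32/[0, 1] output by default, while callers such as the test above can pass `to_dtype=torch.uint8` to compare outputs against saved PNG artifacts. A minimal sketch, assuming an OmegaConf config with the same fields the test fixture uses:

    import torch
    from omegaconf import OmegaConf

    from lerobot.common.datasets.transforms import make_transforms

    cfg = OmegaConf.create(
        {
            "enable": True,
            "max_num_transforms": 1,
            "random_order": False,
            "brightness": {"weight": 1, "min": 0.0, "max": 2.0},
            "contrast": {"weight": 0, "min": 0.0, "max": 2.0},
            "saturation": {"weight": 0, "min": 0.0, "max": 2.0},
            "hue": {"weight": 0, "min": -0.5, "max": 0.5},
            "sharpness": {"weight": 0, "min": 0.0, "max": 2.0},
        }
    )

    train_tf = make_transforms(cfg)  # float32 scaled to [0, 1], as before
    test_tf = make_transforms(cfg, to_dtype=torch.uint8)  # uint8, matches PNG artifacts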