diff --git a/tests/data/save_image_transforms/1336_brightness.png b/tests/data/save_image_transforms/1336_brightness.png new file mode 100644 index 00000000..d6503e16 Binary files /dev/null and b/tests/data/save_image_transforms/1336_brightness.png differ diff --git a/tests/data/save_image_transforms/1336_contrast.png b/tests/data/save_image_transforms/1336_contrast.png new file mode 100644 index 00000000..759a4cae Binary files /dev/null and b/tests/data/save_image_transforms/1336_contrast.png differ diff --git a/tests/data/save_image_transforms/1336_hue.png b/tests/data/save_image_transforms/1336_hue.png new file mode 100644 index 00000000..45420663 Binary files /dev/null and b/tests/data/save_image_transforms/1336_hue.png differ diff --git a/tests/data/save_image_transforms/1336_saturation.png b/tests/data/save_image_transforms/1336_saturation.png new file mode 100644 index 00000000..eb3b7d8a Binary files /dev/null and b/tests/data/save_image_transforms/1336_saturation.png differ diff --git a/tests/data/save_image_transforms/1336_sharpness.png b/tests/data/save_image_transforms/1336_sharpness.png new file mode 100644 index 00000000..af11e14e Binary files /dev/null and b/tests/data/save_image_transforms/1336_sharpness.png differ diff --git a/tests/data/save_image_transforms/original_frame.png b/tests/data/save_image_transforms/original_frame.png new file mode 100644 index 00000000..53297856 Binary files /dev/null and b/tests/data/save_image_transforms/original_frame.png differ diff --git a/tests/scripts/save_image_transforms.py b/tests/scripts/save_image_transforms.py new file mode 100644 index 00000000..8cc7d7ba --- /dev/null +++ b/tests/scripts/save_image_transforms.py @@ -0,0 +1,42 @@ +from pathlib import Path + +from torchvision.transforms import ToPILImage, v2 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import RangeRandomSharpness +from lerobot.common.utils.utils import seeded_context + +DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" +ARTIFACT_DIR = "tests/data/save_image_transforms" +SEED = 1336 +to_pil = ToPILImage() + + +def main(repo_id): + dataset = LeRobotDataset(repo_id, transform=None) + output_dir = Path(ARTIFACT_DIR) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get first frame of given episode + from_idx = dataset.episode_data_index["from"][0].item() + original_frame = dataset[from_idx][dataset.camera_keys[0]] + to_pil(original_frame).save(output_dir / "original_frame.png", quality=100) + + transforms = { + "brightness": v2.ColorJitter(brightness=(0.0, 2.0)), + "contrast": v2.ColorJitter(contrast=(0.0, 2.0)), + "saturation": v2.ColorJitter(saturation=(0.0, 2.0)), + "hue": v2.ColorJitter(hue=(-0.5, 0.5)), + "sharpness": RangeRandomSharpness(0.0, 2.0), + } + + # Apply each single transformation + for name, transform in transforms.items(): + with seeded_context(SEED): + transformed_frame = transform(original_frame) + to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100) + + +if __name__ == "__main__": + repo_id = "lerobot/aloha_mobile_shrimp" + main(repo_id)