diff --git a/examples/6_show_image_transforms.py b/examples/6_show_image_transforms.py new file mode 100644 index 00000000..956381ec --- /dev/null +++ b/examples/6_show_image_transforms.py @@ -0,0 +1,46 @@ +from pathlib import Path + +from torchvision.transforms import ToPILImage, v2 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +to_pil = ToPILImage() + +# Create a directory to store the training checkpoint. +output_dir = Path("outputs/image_transforms") +output_dir.mkdir(parents=True, exist_ok=True) + +repo_id = "lerobot/aloha_static_tape" + +# Create a LeRobotDataset with no transformations +dataset = LeRobotDataset(repo_id, transform=None) + +# Get the index of the first frame in the first episode +first_idx = dataset.episode_data_index["from"][0].item() + +# Get the frame from the first camera +frame = dataset[first_idx][dataset.camera_keys[0]] + +# Save the original frame +to_pil(frame).save(output_dir / "original_frame.png", quality=100) +print(f"Original frame saved to {output_dir / 'original_frame.png'}.") + + +# Define the transformations +transforms = v2.Compose( + [ + v2.ColorJitter(brightness=(0.5, 1.5)), + v2.ColorJitter(contrast=(0.5, 1.5)), + v2.RandomAdjustSharpness(sharpness_factor=2, p=1), + ] +) + +# Create a LeRobotDataset with the defined transformations +transformed_dataset = LeRobotDataset(repo_id, transform=transforms) + +# Get a frame from the transformed dataset +transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]] + +# Save the transformed frame +to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100) +print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")