Added example of torchvision image augmentation on LeRobotDataset

This commit is contained in:
Marina Barannikov 2024-06-05 10:53:18 +00:00
parent fdf56e7a62
commit a544949ebe
1 changed files with 46 additions and 0 deletions

View File

@ -0,0 +1,46 @@
from pathlib import Path
from torchvision.transforms import ToPILImage, v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
to_pil = ToPILImage()
# Create a directory to store the training checkpoint.
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)
repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, transform=None)
# Get the index of the first frame in the first episode
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame from the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
# Save the original frame
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
# Define the transformations
transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)
# Create a LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, transform=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
# Save the transformed frame
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")