diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index a0a84675..e4ba980c 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -47,11 +47,15 @@ class RandomSubsetApply(Transform): def make_transforms(cfg): image_transforms = [] - if 'colorjitter' in cfg.list: - image_transforms.append(v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)) - if 'sharpness' in cfg.list: + if "colorjitter" in cfg.list: + image_transforms.append( + v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor) + ) + if "sharpness" in cfg.list: image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p)) - if 'blur' in cfg.list: + if "blur" in cfg.list: image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p)) - return v2.Compose([RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]) + return v2.Compose( + [RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)] + ) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index daf509b8..04c687bd 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -64,10 +64,10 @@ image_transform: colorjitter_p: 1.O sharpness_factor: 3.0 # Should be more than 1, setting parameter to 1 does not change the image - sharpness_p: 0.5 + sharpness_p: 0.5 blur_factor: 0.2 # Should be less than 1, setting parameter to 1 does not change the image - blur_p: 0.5 + blur_p: 0.5 n_subset: 3 # Maximum number of transforms to apply list: ["colorjitter", "sharpness", "blur"] diff --git a/lerobot/scripts/show_image_transforms.py b/lerobot/scripts/show_image_transforms.py index a2de6304..5dc21c62 100644 --- a/lerobot/scripts/show_image_transforms.py +++ b/lerobot/scripts/show_image_transforms.py @@ -1,15 +1,14 @@ - -from lerobot.common.utils.utils import init_hydra_config -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.transforms import make_transforms - from pathlib import Path import matplotlib.pyplot as plt +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import make_transforms +from lerobot.common.utils.utils import init_hydra_config DEFAULT_CONFIG_PATH = "configs/default.yaml" + def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"): """ Apply a series of image transformations to a frame from a dataset and save the transformed images. @@ -20,7 +19,7 @@ def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/sho episode_index (int, optional): The index of the episode to use. Defaults to 0. output_dir (str, optional): The directory to save the transformed images. Defaults to "outputs/show_image_transforms". """ - + dataset = LeRobotDataset(repo_id) print(f"Getting frame from camera: {dataset.camera_keys[0]}") @@ -41,25 +40,27 @@ def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/sho "image_transform.enable=True", "image_transform.n_subset=1", f"image_transform.{transform}_p=1", - ]) - + ], + ) + cfg = cfg.image_transform t = make_transforms(cfg) - + # Apply transformation to frame transformed_frame = t(frame) transformed_frame = transformed_frame.permute(1, 2, 0).numpy() # Save transformed frame plt.imshow(transformed_frame) - plt.savefig(f'{base_filename}_max_transform_{transform}.png') + plt.savefig(f"{base_filename}_max_transform_{transform}.png") plt.close() frame = frame.permute(1, 2, 0).numpy() + # Save original frame plt.imshow(frame) - plt.savefig(f'{base_filename}_original.png') + plt.savefig(f"{base_filename}_original.png") plt.close() - print(f"Saved transformed images.") + print("Saved transformed images.")