From 5eea2542d9aafa6000f58cd082f03fac8c5c8551 Mon Sep 17 00:00:00 2001
From: Marina Barannikov
Date: Tue, 4 Jun 2024 11:57:45 +0000
Subject: [PATCH] Added visualisations for image augmentation

---
 lerobot/common/datasets/transforms.py    | 10 ++--
 lerobot/configs/default.yaml             |  9 ++--
 lerobot/scripts/show_image_transforms.py | 65 ++++++++++++++++++++++++
 3 files changed, 74 insertions(+), 10 deletions(-)
 create mode 100644 lerobot/scripts/show_image_transforms.py

diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
index 6e8848ea..a0a84675 100644
--- a/lerobot/common/datasets/transforms.py
+++ b/lerobot/common/datasets/transforms.py
@@ -47,11 +47,11 @@ class RandomSubsetApply(Transform):
 
 
 def make_transforms(cfg):
     image_transforms = []
-    if 'jit' in cfg.image_transform.list:
-        image_transforms.append(v2.ColorJitter(brightness=cfg.colorjitter_range, contrast=cfg.colorjitter_range))
-    if 'sharpness' in cfg.image_transform.list:
+    if 'colorjitter' in cfg.list:
+        image_transforms.append(v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor))
+    if 'sharpness' in cfg.list:
         image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
-    if 'blur' in cfg.image_transform.list:
+    if 'blur' in cfg.list:
         image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
-    return v2.Compose(RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True))
+    return v2.Compose([RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)])
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index e4cec15e..daf509b8 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -60,10 +60,9 @@ wandb:
 
 image_transform:
   enable: false
-  colorjittor_range: (0, 1)
-  colorjittor_p: 1
-  # Range from which to sample colorjittor factor
-  sharpness_factor: 3
+  colorjitter_factor: 0.5
+  colorjitter_p: 1.0
+  sharpness_factor: 3.0
   # Should be more than 1, setting parameter to 1 does not change the image
   sharpness_p: 0.5
   blur_factor: 0.2
@@ -71,4 +70,4 @@ image_transform:
   blur_p: 0.5
   n_subset: 3
   # Maximum number of transforms to apply
-  list: ["colorjittor", "sharpness", "blur"]
+  list: ["colorjitter", "sharpness", "blur"]
diff --git a/lerobot/scripts/show_image_transforms.py b/lerobot/scripts/show_image_transforms.py
new file mode 100644
index 00000000..a2de6304
--- /dev/null
+++ b/lerobot/scripts/show_image_transforms.py
@@ -0,0 +1,65 @@
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.transforms import make_transforms
+from lerobot.common.utils.utils import init_hydra_config
+
+DEFAULT_CONFIG_PATH = "configs/default.yaml"
+
+
+def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"):
+    """
+    Apply each configured image transformation, one at a time, to a frame from a dataset and save the results.
+
+    Args:
+        cfg (DictConfig): The `image_transform` configuration node; `cfg.list` determines which transforms are visualised.
+        repo_id (str): The ID of the dataset repository to sample from.
+        episode_index (int, optional): The index of the episode to use. Defaults to 0.
+        output_dir (str, optional): The directory to save the transformed images. Defaults to "outputs/show_image_transforms".
+    """
+    dataset = LeRobotDataset(repo_id)
+
+    print(f"Getting frame from camera: {dataset.camera_keys[0]}")
+
+    # Get the first frame of the given episode
+    from_idx = dataset.episode_data_index["from"][episode_index].item()
+    frame = dataset[from_idx][dataset.camera_keys[0]]
+
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    base_filename = f"{output_dir}/episode_{episode_index}"
+
+    # Apply each transformation on its own (n_subset=1, p=1) and save the result
+    for transform in cfg.list:
+        transform_cfg = init_hydra_config(
+            DEFAULT_CONFIG_PATH,
+            overrides=[
+                f"image_transform.list=[{transform}]",
+                "image_transform.enable=True",
+                "image_transform.n_subset=1",
+                f"image_transform.{transform}_p=1",
+            ],
+        ).image_transform
+
+        t = make_transforms(transform_cfg)
+
+        # Apply the transformation to the frame
+        transformed_frame = t(frame)
+        transformed_frame = transformed_frame.permute(1, 2, 0).numpy()
+
+        # Save the transformed frame
+        plt.imshow(transformed_frame)
+        plt.savefig(f"{base_filename}_max_transform_{transform}.png")
+        plt.close()
+
+    # Save the original frame for comparison
+    frame = frame.permute(1, 2, 0).numpy()
+    plt.imshow(frame)
+    plt.savefig(f"{base_filename}_original.png")
+    plt.close()
+
+    print("Saved transformed images.")
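
Note: the new script only defines show_image_transforms and has no command-line entry point, so it has to be imported and called directly. Below is a minimal usage sketch, not part of the patch; it assumes the working directory makes DEFAULT_CONFIG_PATH resolve to lerobot/configs/default.yaml, and the repo_id is only an example dataset.

    # Minimal usage sketch (assumptions: run from the lerobot/ directory so that
    # "configs/default.yaml" resolves; "lerobot/pusht" is just an example repo_id).
    from lerobot.common.utils.utils import init_hydra_config
    from lerobot.scripts.show_image_transforms import DEFAULT_CONFIG_PATH, show_image_transforms

    # Load the default config and pass only its image_transform node, as the function expects.
    cfg = init_hydra_config(DEFAULT_CONFIG_PATH).image_transform
    show_image_transforms(cfg, repo_id="lerobot/pusht", episode_index=0)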