diff --git a/lerobot/scripts/show_image_transforms.py b/lerobot/scripts/show_image_transforms.py index 5dc21c62..49050a58 100644 --- a/lerobot/scripts/show_image_transforms.py +++ b/lerobot/scripts/show_image_transforms.py @@ -1,66 +1,50 @@ from pathlib import Path -import matplotlib.pyplot as plt +from torchvision.transforms import ToPILImage from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.transforms import make_transforms from lerobot.common.utils.utils import init_hydra_config -DEFAULT_CONFIG_PATH = "configs/default.yaml" +DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" +to_pil = ToPILImage() -def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"): +def main(repo_id): """ Apply a series of image transformations to a frame from a dataset and save the transformed images. Args: - cfg (ConfigNode): The configuration object containing the image transformation settings and the dataset to sample. repo_id (str): The ID of the repository. - episode_index (int, optional): The index of the episode to use. Defaults to 0. - output_dir (str, optional): The directory to save the transformed images. Defaults to "outputs/show_image_transforms". """ - dataset = LeRobotDataset(repo_id) + transforms = ["colorjitter", "sharpness", "blur"] - print(f"Getting frame from camera: {dataset.camera_keys[0]}") + dataset = LeRobotDataset(repo_id, transform=None) + output_dir = Path("outputs/image_transforms") / Path(repo_id) + output_dir.mkdir(parents=True, exist_ok=True) # Get first frame of given episode - from_idx = dataset.episode_data_index["from"][episode_index].item() + from_idx = dataset.episode_data_index["from"][0].item() frame = dataset[from_idx][dataset.camera_keys[0]] + to_pil(frame).save(output_dir / "original_frame.png", quality=100) - Path(output_dir).mkdir(parents=True, exist_ok=True) - base_filename = f"{output_dir}/episode_{episode_index}" - - # Apply each transformation and save the result - for transform in cfg.list: + # Apply each single transformation + for transform_name in transforms: cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=[ - f"image_transform.list=[{transform}]", "image_transform.enable=True", "image_transform.n_subset=1", - f"image_transform.{transform}_p=1", + f"image_transform.list=[{transform_name}]", + f"image_transform.{transform_name}_p=1", ], ) + transform = make_transforms(cfg.image_transform) + img = transform(frame) + to_pil(img).save(output_dir / f"{transform_name}.png", quality=100) - cfg = cfg.image_transform - t = make_transforms(cfg) - - # Apply transformation to frame - transformed_frame = t(frame) - transformed_frame = transformed_frame.permute(1, 2, 0).numpy() - - # Save transformed frame - plt.imshow(transformed_frame) - plt.savefig(f"{base_filename}_max_transform_{transform}.png") - plt.close() - - frame = frame.permute(1, 2, 0).numpy() - - # Save original frame - plt.imshow(frame) - plt.savefig(f"{base_filename}_original.png") - plt.close() - - print("Saved transformed images.") +if __name__ == "__main__": + repo_id = "cadene/reachy2_teleop_remi" + main(repo_id)