diff --git a/lerobot/scripts/show_image_transforms.py b/lerobot/scripts/show_image_transforms.py index e0db3377..a833117b 100644 --- a/lerobot/scripts/show_image_transforms.py +++ b/lerobot/scripts/show_image_transforms.py @@ -1,7 +1,6 @@ from pathlib import Path import hydra - from torchvision.transforms import ToPILImage from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -21,9 +20,9 @@ def main(cfg, output_dir=Path("outputs/image_transforms")): Returns: None """ - + dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None) - + output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1]) output_dir.mkdir(parents=True, exist_ok=True) @@ -36,14 +35,14 @@ def main(cfg, output_dir=Path("outputs/image_transforms")): # Apply each single transformation for transform_name in transforms: - cfg.image_transform.enable=True - cfg.image_transform.max_num_transforms=1 + cfg.image_transform.enable = True + cfg.image_transform.max_num_transforms = 1 for t in transforms: if t == transform_name: - cfg.image_transform[t].weight=1 + cfg.image_transform[t].weight = 1 else: - cfg.image_transform[t].weight=0 + cfg.image_transform[t].weight = 0 transform = make_transforms(cfg.image_transform) img = transform(frame) @@ -56,5 +55,6 @@ def visualize_transforms_cli(cfg: dict): cfg, ) + if __name__ == "__main__": visualize_transforms_cli()