lerobot/lerobot/scripts/visualize_image_transforms.py

131 lines
4.9 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize effects of image transforms for a given configuration.
This script generates examples of transformed images as they are output by a LeRobot dataset.
Additionally, each individual transform is visualized separately, alongside examples of the combined transforms.
Example:
```bash
python lerobot/scripts/visualize_image_transforms.py \
--repo_id=lerobot/pusht \
--episodes='[0]' \
--image_transforms.enable=True
```
"""
import logging
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
import draccus
from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import (
ImageTransforms,
ImageTransformsConfig,
make_transform_from_config,
)
from lerobot.configs.default import DatasetConfig
# Default base directory for the generated images; used as the default value of
# `output_dir` in `visualize_image_transforms`.
OUTPUT_DIR = Path("outputs/image_transforms")
# Shared converter instance, called on every frame before it is saved as a PNG file.
to_pil = ToPILImage()
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
    """Save examples of the combined transforms applied to a single frame.

    An ``ImageTransforms`` object is built from ``cfg`` and applied ``n_examples``
    times to ``original_frame``. The results are written as ``1.png`` ...
    ``<n_examples>.png`` inside ``output_dir / "all"``, which is created if needed.

    Args:
        cfg: Image transforms configuration handed to ``ImageTransforms``.
        original_frame: Frame the transforms are applied to.
        output_dir: Base directory; must support the ``/`` operator and ``mkdir``
            (e.g. a ``pathlib.Path``).
        n_examples: Number of transformed examples to save.
    """
    combined_dir = output_dir / "all"
    combined_dir.mkdir(parents=True, exist_ok=True)

    pipeline = ImageTransforms(cfg)
    for shot in range(1, n_examples + 1):
        destination = combined_dir / f"{shot}.png"
        to_pil(pipeline(original_frame)).save(destination, quality=100)

    print("Combined transforms examples saved to:")
    print(f" {combined_dir}")
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
    """Save examples of every configured transform applied on its own.

    When ``cfg.enable`` is falsy, a warning is logged and nothing is saved.
    Otherwise, for each ``(name, config)`` entry of ``cfg.tfs`` a directory
    ``output_dir / name`` is created and filled with:

    * ``1.png`` ... ``<n_examples>.png``: outputs of the transform built from the
      unmodified config;
    * ``min.png``, ``max.png`` and ``mean.png``: outputs of the transform rebuilt
      with every kwarg range collapsed to its lower bound, its upper bound and
      the midpoint of both, respectively.

    Every value in a transform's ``kwargs`` is unpacked as a ``(min, max)`` pair.

    Args:
        cfg: Image transforms configuration providing ``enable`` and ``tfs``.
        original_frame: Frame the transforms are applied to.
        output_dir: Base directory; must support the ``/`` operator and ``mkdir``
            (e.g. a ``pathlib.Path``).
        n_examples: Number of examples saved per transform, in addition to the
            min / max / mean images.
    """
    if not cfg.enable:
        logging.warning(
            "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
        )
        return

    print("Individual transforms examples saved to:")
    pinned_filenames = ("min.png", "max.png", "mean.png")
    for name, single_cfg in cfg.tfs.items():
        target_dir = output_dir / name
        target_dir.mkdir(parents=True, exist_ok=True)

        # A few examples produced by the transform exactly as configured.
        sampler = make_transform_from_config(single_cfg)
        for shot in range(1, n_examples + 1):
            to_pil(sampler(original_frame)).save(target_dir / f"{shot}.png", quality=100)

        # Collapse each (min, max) range to a single value so that the lower
        # bound, upper bound and midpoint of the transform can be inspected.
        lowest = deepcopy(single_cfg.kwargs)
        highest = deepcopy(single_cfg.kwargs)
        middle = deepcopy(single_cfg.kwargs)
        for param, (low, high) in single_cfg.kwargs.items():
            midpoint = (low + high) / 2
            lowest[param] = [low, low]
            highest[param] = [high, high]
            middle[param] = [midpoint, midpoint]

        # Build all three transforms, then apply them, then save the results,
        # in min -> max -> mean order.
        pinned_tfs = [
            make_transform_from_config(replace(single_cfg, kwargs=pinned_kwargs))
            for pinned_kwargs in (lowest, highest, middle)
        ]
        pinned_frames = [pinned_tf(original_frame) for pinned_tf in pinned_tfs]
        for filename, frame in zip(pinned_filenames, pinned_frames):
            to_pil(frame).save(target_dir / filename, quality=100)

        print(f" {target_dir}")
@draccus.wrap()
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
    """Save an original dataset frame together with transformed versions of it.

    Loads the dataset described by ``cfg``, takes the frame at index 0 for the
    first camera key, and writes into ``output_dir / <last part of repo_id>``:

    * ``original_frame.png``: the untouched frame;
    * the images produced by ``save_all_transforms`` and ``save_each_transform``
      for ``cfg.image_transforms``.

    Args:
        cfg: Dataset configuration; ``repo_id``, ``episodes``, ``revision``,
            ``video_backend`` and ``image_transforms`` are read from it.
        output_dir: Base directory under which the per-dataset folder is created.
        n_examples: Number of examples to save for each visualization.

    Raises:
        ValueError: If the dataset exposes no camera keys, since there is then
            no image to apply the transforms to.
    """
    dataset = LeRobotDataset(
        repo_id=cfg.repo_id,
        episodes=cfg.episodes,
        revision=cfg.revision,
        video_backend=cfg.video_backend,
    )

    # Fail early with an explicit message (and before creating any directory)
    # rather than with a bare "list index out of range" further down.
    camera_keys = dataset.meta.camera_keys
    if len(camera_keys) == 0:
        raise ValueError(
            f"Dataset '{cfg.repo_id}' has no camera keys: there is no image to visualize transforms on."
        )

    output_dir = output_dir / cfg.repo_id.split("/")[-1]
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get the frame at index 0 of the loaded dataset, from its first camera.
    original_frame = dataset[0][camera_keys[0]]
    to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
    print("\nOriginal frame saved to:")
    print(f" {output_dir / 'original_frame.png'}.")

    save_all_transforms(cfg.image_transforms, original_frame, output_dir, n_examples)
    save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
if __name__ == "__main__":
    # Script entry point. The function is called without arguments; its parameters
    # are presumably filled in from the command line by the `draccus.wrap()`
    # decorator (see the usage example in the module docstring) — TODO confirm.
    visualize_image_transforms()