refactor show_image_transforms
This commit is contained in:
parent
22bd1f0669
commit
443b06b412
|
@ -1,66 +1,50 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
from torchvision.transforms import ToPILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.transforms import make_transforms
|
from lerobot.common.datasets.transforms import make_transforms
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
|
||||||
DEFAULT_CONFIG_PATH = "configs/default.yaml"
|
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||||
|
to_pil = ToPILImage()
|
||||||
|
|
||||||
|
|
||||||
def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"):
|
def main(repo_id):
|
||||||
"""
|
"""
|
||||||
Apply a series of image transformations to a frame from a dataset and save the transformed images.
|
Apply a series of image transformations to a frame from a dataset and save the transformed images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (ConfigNode): The configuration object containing the image transformation settings and the dataset to sample.
|
|
||||||
repo_id (str): The ID of the repository.
|
repo_id (str): The ID of the repository.
|
||||||
episode_index (int, optional): The index of the episode to use. Defaults to 0.
|
|
||||||
output_dir (str, optional): The directory to save the transformed images. Defaults to "outputs/show_image_transforms".
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataset = LeRobotDataset(repo_id)
|
transforms = ["colorjitter", "sharpness", "blur"]
|
||||||
|
|
||||||
print(f"Getting frame from camera: {dataset.camera_keys[0]}")
|
dataset = LeRobotDataset(repo_id, transform=None)
|
||||||
|
output_dir = Path("outputs/image_transforms") / Path(repo_id)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Get first frame of given episode
|
# Get first frame of given episode
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
from_idx = dataset.episode_data_index["from"][0].item()
|
||||||
frame = dataset[from_idx][dataset.camera_keys[0]]
|
frame = dataset[from_idx][dataset.camera_keys[0]]
|
||||||
|
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
|
||||||
|
|
||||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
# Apply each single transformation
|
||||||
base_filename = f"{output_dir}/episode_{episode_index}"
|
for transform_name in transforms:
|
||||||
|
|
||||||
# Apply each transformation and save the result
|
|
||||||
for transform in cfg.list:
|
|
||||||
cfg = init_hydra_config(
|
cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
overrides=[
|
overrides=[
|
||||||
f"image_transform.list=[{transform}]",
|
|
||||||
"image_transform.enable=True",
|
"image_transform.enable=True",
|
||||||
"image_transform.n_subset=1",
|
"image_transform.n_subset=1",
|
||||||
f"image_transform.{transform}_p=1",
|
f"image_transform.list=[{transform_name}]",
|
||||||
|
f"image_transform.{transform_name}_p=1",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
transform = make_transforms(cfg.image_transform)
|
||||||
|
img = transform(frame)
|
||||||
|
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
|
||||||
|
|
||||||
cfg = cfg.image_transform
|
|
||||||
|
|
||||||
t = make_transforms(cfg)
|
if __name__ == "__main__":
|
||||||
|
repo_id = "cadene/reachy2_teleop_remi"
|
||||||
# Apply transformation to frame
|
main(repo_id)
|
||||||
transformed_frame = t(frame)
|
|
||||||
transformed_frame = transformed_frame.permute(1, 2, 0).numpy()
|
|
||||||
|
|
||||||
# Save transformed frame
|
|
||||||
plt.imshow(transformed_frame)
|
|
||||||
plt.savefig(f"{base_filename}_max_transform_{transform}.png")
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
frame = frame.permute(1, 2, 0).numpy()
|
|
||||||
|
|
||||||
# Save original frame
|
|
||||||
plt.imshow(frame)
|
|
||||||
plt.savefig(f"{base_filename}_original.png")
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
print("Saved transformed images.")
|
|
||||||
|
|
Loading…
Reference in New Issue