Added visualisations for image augmentation
This commit is contained in:
parent
42f9cc9c2a
commit
5eea2542d9
|
def make_transforms(cfg):
    """Build an image-augmentation pipeline from the `image_transform` config node.

    Args:
        cfg: Config node providing `list` (names of enabled transforms:
            "colorjitter", "sharpness", "blur"), the per-transform factor and
            probability fields read below, and `n_subset` (maximum number of
            transforms applied per call).

    Returns:
        A torchvision `v2.Compose` that applies a random subset (at most
        `cfg.n_subset`) of the enabled transforms, then converts the image
        to float32 scaled to [0, 1] via `v2.ToDtype`.
    """
    image_transforms = []

    if 'colorjitter' in cfg.list:
        # Same factor drives both brightness and contrast jitter.
        image_transforms.append(
            v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
        )

    if 'sharpness' in cfg.list:
        image_transforms.append(
            v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p)
        )

    if 'blur' in cfg.list:
        # RandomAdjustSharpness with a factor below 1 blurs the image, so
        # "blur" reuses the same transform with a small factor.
        # NOTE(review): assumes cfg.blur_factor < 1 — confirm in the config.
        image_transforms.append(
            v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p)
        )

    return v2.Compose(
        [
            RandomSubsetApply(image_transforms, n_subset=cfg.n_subset),
            v2.ToDtype(torch.float32, scale=True),
        ]
    )
@ -60,10 +60,9 @@ wandb:
|
||||||
|
|
||||||
image_transform:
|
image_transform:
|
||||||
enable: false
|
enable: false
|
||||||
colorjittor_range: (0, 1)
|
colorjitter_factor: 0.5
|
||||||
colorjittor_p: 1
|
colorjitter_p: 1.0
|
||||||
# Range from which to sample colorjittor factor
|
sharpness_factor: 3.0
|
||||||
sharpness_factor: 3
|
|
||||||
# Should be more than 1, setting parameter to 1 does not change the image
|
# Should be more than 1, setting parameter to 1 does not change the image
|
||||||
sharpness_p: 0.5
|
sharpness_p: 0.5
|
||||||
blur_factor: 0.2
|
blur_factor: 0.2
|
||||||
|
@ -71,4 +70,4 @@ image_transform:
|
||||||
blur_p: 0.5
|
blur_p: 0.5
|
||||||
n_subset: 3
|
n_subset: 3
|
||||||
# Maximum number of transforms to apply
|
# Maximum number of transforms to apply
|
||||||
list: ["colorjittor", "sharpness", "blur"]
|
list: ["colorjitter", "sharpness", "blur"]
|
||||||
|
|
|
from pathlib import Path

import matplotlib.pyplot as plt

from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_transforms

DEFAULT_CONFIG_PATH = "configs/default.yaml"


def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"):
    """
    Apply each configured image transform in isolation to a dataset frame
    and save the original plus every transformed image as PNG files.

    Args:
        cfg (ConfigNode): The `image_transform` config node; only `cfg.list`
            (the transform names) is read — each transform's parameters are
            re-loaded from DEFAULT_CONFIG_PATH with its probability forced to 1.
        repo_id (str): The ID of the dataset repository.
        episode_index (int, optional): The index of the episode to use. Defaults to 0.
        output_dir (str, optional): The directory to save the transformed images.
            Defaults to "outputs/show_image_transforms".
    """
    dataset = LeRobotDataset(repo_id)

    print(f"Getting frame from camera: {dataset.camera_keys[0]}")

    # Get first frame of given episode.
    from_idx = dataset.episode_data_index["from"][episode_index].item()
    frame = dataset[from_idx][dataset.camera_keys[0]]

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    base_filename = f"{output_dir}/episode_{episode_index}"

    # Apply each transformation in isolation (p forced to 1, n_subset=1 so
    # exactly this transform fires) and save the result.
    for transform in cfg.list:
        # Use a dedicated name instead of rebinding `cfg`, so the caller's
        # config is not shadowed mid-loop.
        transform_cfg = init_hydra_config(
            DEFAULT_CONFIG_PATH,
            overrides=[
                f"image_transform.list=[{transform}]",
                "image_transform.enable=True",
                "image_transform.n_subset=1",
                f"image_transform.{transform}_p=1",
            ],
        ).image_transform

        t = make_transforms(transform_cfg)

        # Apply transformation to frame; CHW tensor -> HWC array for imshow.
        transformed_frame = t(frame).permute(1, 2, 0).numpy()

        # Save transformed frame.
        plt.imshow(transformed_frame)
        plt.savefig(f'{base_filename}_max_transform_{transform}.png')
        plt.close()

    # Save original frame for reference.
    plt.imshow(frame.permute(1, 2, 0).numpy())
    plt.savefig(f'{base_filename}_original.png')
    plt.close()

    print("Saved transformed images.")
Loading…
Reference in New Issue