Updated formatting

Marina Barannikov 2024-06-04 12:06:36 +00:00
parent 31e3c82386
commit 22bd1f0669
3 changed files with 24 additions and 19 deletions


@@ -47,11 +47,15 @@ class RandomSubsetApply(Transform):
 def make_transforms(cfg):
     image_transforms = []
-    if 'colorjitter' in cfg.list:
-        image_transforms.append(v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor))
-    if 'sharpness' in cfg.list:
+    if "colorjitter" in cfg.list:
+        image_transforms.append(
+            v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
+        )
+    if "sharpness" in cfg.list:
         image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
-    if 'blur' in cfg.list:
+    if "blur" in cfg.list:
         image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
-    return v2.Compose([RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)])
+    return v2.Compose(
+        [RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
+    )
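
For context, a minimal usage sketch of the reformatted make_transforms (not part of this commit; the OmegaConf config built below and the dummy frame are assumptions that mirror the values in configs/default.yaml):

# Hypothetical usage sketch, assuming a Hydra/OmegaConf config shaped like configs/default.yaml.
import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.transforms import make_transforms

cfg = OmegaConf.create(
    {
        "colorjitter_factor": 0.5,  # assumed value; not shown in this diff
        "colorjitter_p": 1.0,
        "sharpness_factor": 3.0,
        "sharpness_p": 0.5,
        "blur_factor": 0.2,
        "blur_p": 0.5,
        "n_subset": 3,
        "list": ["colorjitter", "sharpness", "blur"],
    }
)

transform = make_transforms(cfg)
frame = torch.rand(3, 224, 224)   # dummy CHW image in [0, 1]
augmented = transform(frame)      # applies a random subset of the configured transforms
print(augmented.shape, augmented.dtype)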


@@ -64,10 +64,10 @@ image_transform:
   colorjitter_p: 1.0
   sharpness_factor: 3.0
   # Should be more than 1; setting the parameter to 1 does not change the image
   sharpness_p: 0.5
   blur_factor: 0.2
   # Should be less than 1; setting the parameter to 1 does not change the image
   blur_p: 0.5
   n_subset: 3
   # Maximum number of transforms to apply
   list: ["colorjitter", "sharpness", "blur"]
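
n_subset caps how many of the listed transforms are applied to a given frame via RandomSubsetApply from the first file. As a rough illustration only (a simplified stand-in, not the class actually defined in the repo), the idea can be sketched as:

# Simplified stand-in for the subset-sampling idea; NOT the repo's RandomSubsetApply.
import random

import torch
from torchvision.transforms import v2


class RandomSubsetApplySketch:
    """Apply a random subset of at most n_subset transforms, in random order."""

    def __init__(self, transforms, n_subset):
        self.transforms = list(transforms)
        self.n_subset = n_subset

    def __call__(self, img):
        chosen = random.sample(self.transforms, k=min(self.n_subset, len(self.transforms)))
        for t in chosen:
            img = t(img)
        return img


if __name__ == "__main__":
    t = RandomSubsetApplySketch(
        [v2.ColorJitter(brightness=0.5), v2.RandomAdjustSharpness(3.0, p=0.5)],
        n_subset=1,
    )
    print(t(torch.rand(3, 64, 64)).shape)  # torch.Size([3, 64, 64])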


@@ -1,15 +1,14 @@
-from lerobot.common.utils.utils import init_hydra_config
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.transforms import make_transforms
 from pathlib import Path
 import matplotlib.pyplot as plt
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.transforms import make_transforms
+from lerobot.common.utils.utils import init_hydra_config
 
 DEFAULT_CONFIG_PATH = "configs/default.yaml"
 
 def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"):
     """
     Apply a series of image transformations to a frame from a dataset and save the transformed images.
@@ -20,7 +19,7 @@ def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/sho
         episode_index (int, optional): The index of the episode to use. Defaults to 0.
         output_dir (str, optional): The directory to save the transformed images. Defaults to "outputs/show_image_transforms".
     """
     dataset = LeRobotDataset(repo_id)
     print(f"Getting frame from camera: {dataset.camera_keys[0]}")
@@ -41,25 +40,27 @@ def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/sho
                 "image_transform.enable=True",
                 "image_transform.n_subset=1",
                 f"image_transform.{transform}_p=1",
-            ])
+            ],
+        )
         cfg = cfg.image_transform
         t = make_transforms(cfg)
 
         # Apply transformation to frame
         transformed_frame = t(frame)
         transformed_frame = transformed_frame.permute(1, 2, 0).numpy()
 
         # Save transformed frame
         plt.imshow(transformed_frame)
-        plt.savefig(f'{base_filename}_max_transform_{transform}.png')
+        plt.savefig(f"{base_filename}_max_transform_{transform}.png")
         plt.close()
 
     frame = frame.permute(1, 2, 0).numpy()
 
     # Save original frame
     plt.imshow(frame)
-    plt.savefig(f'{base_filename}_original.png')
+    plt.savefig(f"{base_filename}_original.png")
     plt.close()
 
-    print(f"Saved transformed images.")
+    print("Saved transformed images.")