Updated formatting
This commit is contained in:
parent
31e3c82386
commit
22bd1f0669
|
@ -47,11 +47,15 @@ class RandomSubsetApply(Transform):
|
||||||
|
|
||||||
def make_transforms(cfg):
|
def make_transforms(cfg):
|
||||||
image_transforms = []
|
image_transforms = []
|
||||||
if 'colorjitter' in cfg.list:
|
if "colorjitter" in cfg.list:
|
||||||
image_transforms.append(v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor))
|
image_transforms.append(
|
||||||
if 'sharpness' in cfg.list:
|
v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
|
||||||
|
)
|
||||||
|
if "sharpness" in cfg.list:
|
||||||
image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
|
image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
|
||||||
if 'blur' in cfg.list:
|
if "blur" in cfg.list:
|
||||||
image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
|
image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
|
||||||
|
|
||||||
return v2.Compose([RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)])
|
return v2.Compose(
|
||||||
|
[RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
|
||||||
|
)
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
|
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.common.datasets.transforms import make_transforms
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.datasets.transforms import make_transforms
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
|
||||||
DEFAULT_CONFIG_PATH = "configs/default.yaml"
|
DEFAULT_CONFIG_PATH = "configs/default.yaml"
|
||||||
|
|
||||||
|
|
||||||
def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"):
|
def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/show_image_transforms"):
|
||||||
"""
|
"""
|
||||||
Apply a series of image transformations to a frame from a dataset and save the transformed images.
|
Apply a series of image transformations to a frame from a dataset and save the transformed images.
|
||||||
|
@ -41,7 +40,8 @@ def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/sho
|
||||||
"image_transform.enable=True",
|
"image_transform.enable=True",
|
||||||
"image_transform.n_subset=1",
|
"image_transform.n_subset=1",
|
||||||
f"image_transform.{transform}_p=1",
|
f"image_transform.{transform}_p=1",
|
||||||
])
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cfg = cfg.image_transform
|
cfg = cfg.image_transform
|
||||||
|
|
||||||
|
@ -53,13 +53,14 @@ def show_image_transforms(cfg, repo_id, episode_index=0, output_dir="outputs/sho
|
||||||
|
|
||||||
# Save transformed frame
|
# Save transformed frame
|
||||||
plt.imshow(transformed_frame)
|
plt.imshow(transformed_frame)
|
||||||
plt.savefig(f'{base_filename}_max_transform_{transform}.png')
|
plt.savefig(f"{base_filename}_max_transform_{transform}.png")
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
frame = frame.permute(1, 2, 0).numpy()
|
frame = frame.permute(1, 2, 0).numpy()
|
||||||
|
|
||||||
# Save original frame
|
# Save original frame
|
||||||
plt.imshow(frame)
|
plt.imshow(frame)
|
||||||
plt.savefig(f'{base_filename}_original.png')
|
plt.savefig(f"{base_filename}_original.png")
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
print(f"Saved transformed images.")
|
print("Saved transformed images.")
|
||||||
|
|
Loading…
Reference in New Issue