Updated show_transform to match config

This commit is contained in:
Marina Barannikov 2024-06-05 12:00:59 +00:00
parent ceb95592af
commit 4dbc1adb0d
1 changed files with 16 additions and 9 deletions

View File

@ -24,21 +24,28 @@ def main(repo_id):
output_dir = Path("outputs/image_transforms") / Path(repo_id) output_dir = Path("outputs/image_transforms") / Path(repo_id)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Get first frame of given episode # Get first frame of 1st episode
from_idx = dataset.episode_data_index["from"][0].item() first_idx = dataset.episode_data_index["from"][0].item()
frame = dataset[from_idx][dataset.camera_keys[0]] frame = dataset[first_idx][dataset.camera_keys[0]]
to_pil(frame).save(output_dir / "original_frame.png", quality=100) to_pil(frame).save(output_dir / "original_frame.png", quality=100)
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
# Apply each single transformation # Apply each single transformation
for transform_name in transforms: for transform_name in transforms:
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides = [ overrides = [
"image_transform.enable=True", "image_transform.enable=True",
"image_transform.n_subset=1", "image_transform.max_num_transforms=1",
f"image_transform.list=[{transform_name}]", ]
f"image_transform.{transform_name}_p=1", for t in transforms:
], if t == transform_name:
overrides.append(f"image_transform.{t}.weight=1")
overrides.append(f"image_transform.{t}_p=1")
else:
overrides.append(f"image_transform.{t}.weight=0")
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=overrides,
) )
transform = make_transforms(cfg.image_transform) transform = make_transforms(cfg.image_transform)
img = transform(frame) img = transform(frame)