added visualization for min and max transforms (#271)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
This commit is contained in:
parent
a92d79fff2
commit
e28fa2344c
|
@ -65,11 +65,10 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.transforms import get_image_transforms
|
from lerobot.common.datasets.transforms import get_image_transforms
|
||||||
|
|
||||||
OUTPUT_DIR = Path("outputs/image_transforms")
|
OUTPUT_DIR = Path("outputs/image_transforms")
|
||||||
N_EXAMPLES = 5
|
|
||||||
to_pil = ToPILImage()
|
to_pil = ToPILImage()
|
||||||
|
|
||||||
|
|
||||||
def save_config_all_transforms(cfg, original_frame, output_dir):
|
def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
|
||||||
tf = get_image_transforms(
|
tf = get_image_transforms(
|
||||||
brightness_weight=cfg.brightness.weight,
|
brightness_weight=cfg.brightness.weight,
|
||||||
brightness_min_max=cfg.brightness.min_max,
|
brightness_min_max=cfg.brightness.min_max,
|
||||||
|
@ -88,7 +87,7 @@ def save_config_all_transforms(cfg, original_frame, output_dir):
|
||||||
output_dir_all = output_dir / "all"
|
output_dir_all = output_dir / "all"
|
||||||
output_dir_all.mkdir(parents=True, exist_ok=True)
|
output_dir_all.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
for i in range(1, N_EXAMPLES + 1):
|
for i in range(1, n_examples + 1):
|
||||||
transformed_frame = tf(original_frame)
|
transformed_frame = tf(original_frame)
|
||||||
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
|
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
|
||||||
|
|
||||||
|
@ -96,7 +95,7 @@ def save_config_all_transforms(cfg, original_frame, output_dir):
|
||||||
print(f" {output_dir_all}")
|
print(f" {output_dir_all}")
|
||||||
|
|
||||||
|
|
||||||
def save_config_single_transforms(cfg, original_frame, output_dir):
|
def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
|
||||||
transforms = [
|
transforms = [
|
||||||
"brightness",
|
"brightness",
|
||||||
"contrast",
|
"contrast",
|
||||||
|
@ -106,6 +105,7 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
|
||||||
]
|
]
|
||||||
print("Individual transforms examples saved to:")
|
print("Individual transforms examples saved to:")
|
||||||
for transform in transforms:
|
for transform in transforms:
|
||||||
|
# Apply one transformation with random value in min_max range
|
||||||
kwargs = {
|
kwargs = {
|
||||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||||
f"{transform}_min_max": cfg[f"{transform}"].min_max,
|
f"{transform}_min_max": cfg[f"{transform}"].min_max,
|
||||||
|
@ -114,18 +114,46 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
|
||||||
output_dir_single = output_dir / f"{transform}"
|
output_dir_single = output_dir / f"{transform}"
|
||||||
output_dir_single.mkdir(parents=True, exist_ok=True)
|
output_dir_single.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
for i in range(1, N_EXAMPLES + 1):
|
for i in range(1, n_examples + 1):
|
||||||
transformed_frame = tf(original_frame)
|
transformed_frame = tf(original_frame)
|
||||||
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
|
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
|
||||||
|
|
||||||
|
# Apply min transformation
|
||||||
|
min_value, max_value = cfg[f"{transform}"].min_max
|
||||||
|
kwargs = {
|
||||||
|
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||||
|
f"{transform}_min_max": (min_value, min_value),
|
||||||
|
}
|
||||||
|
tf = get_image_transforms(**kwargs)
|
||||||
|
transformed_frame = tf(original_frame)
|
||||||
|
to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100)
|
||||||
|
|
||||||
|
# Apply max transformation
|
||||||
|
kwargs = {
|
||||||
|
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||||
|
f"{transform}_min_max": (max_value, max_value),
|
||||||
|
}
|
||||||
|
tf = get_image_transforms(**kwargs)
|
||||||
|
transformed_frame = tf(original_frame)
|
||||||
|
to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100)
|
||||||
|
|
||||||
|
# Apply mean transformation
|
||||||
|
mean_value = (min_value + max_value) / 2
|
||||||
|
kwargs = {
|
||||||
|
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||||
|
f"{transform}_min_max": (mean_value, mean_value),
|
||||||
|
}
|
||||||
|
tf = get_image_transforms(**kwargs)
|
||||||
|
transformed_frame = tf(original_frame)
|
||||||
|
to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)
|
||||||
|
|
||||||
print(f" {output_dir_single}")
|
print(f" {output_dir_single}")
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
|
||||||
def visualize_transforms(cfg):
|
|
||||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||||
|
|
||||||
output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
|
output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Get 1st frame from 1st camera of 1st episode
|
# Get 1st frame from 1st camera of 1st episode
|
||||||
|
@ -134,8 +162,13 @@ def visualize_transforms(cfg):
|
||||||
print("\nOriginal frame saved to:")
|
print("\nOriginal frame saved to:")
|
||||||
print(f" {output_dir / 'original_frame.png'}.")
|
print(f" {output_dir / 'original_frame.png'}.")
|
||||||
|
|
||||||
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
|
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||||
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
|
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
|
def visualize_transforms_cli(cfg):
|
||||||
|
visualize_transforms(cfg, output_dir=OUTPUT_DIR)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -26,6 +26,7 @@ from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
|
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
||||||
|
|
||||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||||
|
@ -258,3 +259,44 @@ def test_sharpness_jitter_invalid_range_min_negative():
|
||||||
def test_sharpness_jitter_invalid_range_max_smaller():
|
def test_sharpness_jitter_invalid_range_max_smaller():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
SharpnessJitter((2.0, 0.1))
|
SharpnessJitter((2.0, 0.1))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"repo_id, n_examples",
|
||||||
|
[
|
||||||
|
("lerobot/aloha_sim_transfer_cube_human", 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_visualize_image_transforms(repo_id, n_examples):
|
||||||
|
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
|
||||||
|
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
|
||||||
|
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
|
||||||
|
output_dir = output_dir / repo_id.split("/")[-1]
|
||||||
|
|
||||||
|
# Check if the original frame image exists
|
||||||
|
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."
|
||||||
|
|
||||||
|
# Check if the transformed images exist for each transform type
|
||||||
|
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
||||||
|
for transform in transforms:
|
||||||
|
transform_dir = output_dir / transform
|
||||||
|
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||||
|
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
||||||
|
|
||||||
|
# Check for specific files within each transform directory
|
||||||
|
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
|
||||||
|
for file_name in expected_files:
|
||||||
|
assert (
|
||||||
|
transform_dir / file_name
|
||||||
|
).exists(), f"{file_name} was not found in {transform} directory."
|
||||||
|
|
||||||
|
# Check if the combined transforms directory exists and contains the right files
|
||||||
|
combined_transforms_dir = output_dir / "all"
|
||||||
|
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||||
|
assert any(
|
||||||
|
combined_transforms_dir.iterdir()
|
||||||
|
), "No transformed images found in combined transforms directory."
|
||||||
|
for i in range(1, n_examples + 1):
|
||||||
|
assert (
|
||||||
|
combined_transforms_dir / f"{i}.png"
|
||||||
|
).exists(), f"Combined transform image {i}.png was not found."
|
||||||
|
|
Loading…
Reference in New Issue