fix TestMakeImageTransforms

This commit is contained in:
Simon Alibert 2024-06-07 17:23:54 +02:00
parent b60810a8b6
commit e52942a200
6 changed files with 23 additions and 39 deletions

Binary file not shown.

Before

(image error) Size: 118 KiB

Binary file not shown.

Before

(image error) Size: 118 KiB

Binary file not shown.

Before

(image error) Size: 187 KiB

Binary file not shown.

Before

(image error) Size: 168 KiB

Binary file not shown.

Before

(image error) Size: 224 KiB

View File

@ -1,17 +1,15 @@
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
import pytest
import torch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from PIL import Image
from safetensors.torch import load_file
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH
from lerobot.common.utils.utils import seeded_context
class TestRandomSubsetApply:
@ -83,7 +81,7 @@ class TestRangeRandomSharpness:
RangeRandomSharpness(2.0, 0.1)
class TestMakeTransforms:
class TestMakeImageTransforms:
@pytest.fixture(autouse=True)
def setup(self):
"""Seed should be the same as the one that was used to generate artifacts"""
@ -91,47 +89,41 @@ class TestMakeTransforms:
"enable": True,
"max_num_transforms": 1,
"random_order": False,
"brightness": {
"weight": 0,
"min": 0.0,
"max": 2.0
},
"brightness": {"weight": 0, "min": 2.0, "max": 2.0},
"contrast": {
"weight": 0,
"min": 0.0,
"min": 2.0,
"max": 2.0,
},
"saturation": {
"weight": 0,
"min": 0.0,
"min": 2.0,
"max": 2.0,
},
"hue": {
"weight": 0,
"min": -0.5,
"min": 0.5,
"max": 0.5,
},
"sharpness": {
"weight": 0,
"min": 0.0,
"min": 2.0,
"max": 2.0,
},
}
self.path = Path("tests/data/save_image_transforms")
# self.expected_frames = load_file(self.path / f"transformed_frames_1336.safetensors")
self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
# self.original_frame = self.expected_frames["original_frame"]
self.transforms = {
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
"sharpness": RangeRandomSharpness(0.0, 2.0),
"brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
"contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
"saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
"hue": v2.ColorJitter(hue=(0.5, 0.5)),
"sharpness": RangeRandomSharpness(2.0, 2.0),
}
@staticmethod
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert('RGB'))).permute(2, 0, 1)
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
@pytest.mark.parametrize(
"transform_key, seed",
@ -141,27 +133,19 @@ class TestMakeTransforms:
("saturation", 1336),
("hue", 1336),
("sharpness", 1336),
]
],
)
def test_single_transform(self, transform_key, seed):
config = self.config
config[transform_key]["weight"] = 1
cfg = OmegaConf.create(config)
transform = make_image_transforms(cfg, to_dtype=torch.uint8)
# expected_t = self.transforms[transform_key]
with seeded_context(seed):
actual = transform(self.original_frame)
# torch.manual_seed(42)
# actual = actual_t(self.original_frame)
# torch.manual_seed(42)
# expected = expected_t(self.original_frame)
actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
with seeded_context(1336):
actual = actual_t(self.original_frame)
# with seeded_context(1336):
# expected = expected_t(self.original_frame)
expected_t = self.transforms[transform_key]
with seeded_context(1336):
expected = expected_t(self.original_frame)
expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png")
# # expected = self.expected_frames[transform_key]
to_pil = v2.ToPILImage()
to_pil(actual).save(self.path / f"{seed}_{transform_key}_test.png", quality=100)
torch.testing.assert_close(actual, expected)