fix TestMakeImageTransforms
This commit is contained in:
parent
b60810a8b6
commit
e52942a200
tests
data/save_image_transforms
test_transforms.py
Binary file not shown.
Before ![]() (image error) Size: 118 KiB |
Binary file not shown.
Before ![]() (image error) Size: 118 KiB |
Binary file not shown.
Before ![]() (image error) Size: 187 KiB |
Binary file not shown.
Before ![]() (image error) Size: 168 KiB |
Binary file not shown.
Before ![]() (image error) Size: 224 KiB |
|
@ -1,17 +1,15 @@
|
|||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
import pytest
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
|
||||
from lerobot.common.datasets.utils import flatten_dict
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
|
||||
|
||||
class TestRandomSubsetApply:
|
||||
|
@ -83,7 +81,7 @@ class TestRangeRandomSharpness:
|
|||
RangeRandomSharpness(2.0, 0.1)
|
||||
|
||||
|
||||
class TestMakeTransforms:
|
||||
class TestMakeImageTransforms:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
"""Seed should be the same as the one that was used to generate artifacts"""
|
||||
|
@ -91,47 +89,41 @@ class TestMakeTransforms:
|
|||
"enable": True,
|
||||
"max_num_transforms": 1,
|
||||
"random_order": False,
|
||||
"brightness": {
|
||||
"weight": 0,
|
||||
"min": 0.0,
|
||||
"max": 2.0
|
||||
},
|
||||
"brightness": {"weight": 0, "min": 2.0, "max": 2.0},
|
||||
"contrast": {
|
||||
"weight": 0,
|
||||
"min": 0.0,
|
||||
"min": 2.0,
|
||||
"max": 2.0,
|
||||
},
|
||||
"saturation": {
|
||||
"weight": 0,
|
||||
"min": 0.0,
|
||||
"min": 2.0,
|
||||
"max": 2.0,
|
||||
},
|
||||
"hue": {
|
||||
"weight": 0,
|
||||
"min": -0.5,
|
||||
"min": 0.5,
|
||||
"max": 0.5,
|
||||
},
|
||||
"sharpness": {
|
||||
"weight": 0,
|
||||
"min": 0.0,
|
||||
"min": 2.0,
|
||||
"max": 2.0,
|
||||
},
|
||||
}
|
||||
self.path = Path("tests/data/save_image_transforms")
|
||||
# self.expected_frames = load_file(self.path / f"transformed_frames_1336.safetensors")
|
||||
self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
|
||||
# self.original_frame = self.expected_frames["original_frame"]
|
||||
self.transforms = {
|
||||
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
|
||||
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
|
||||
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
|
||||
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
|
||||
"sharpness": RangeRandomSharpness(0.0, 2.0),
|
||||
"brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
|
||||
"contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
|
||||
"saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
|
||||
"hue": v2.ColorJitter(hue=(0.5, 0.5)),
|
||||
"sharpness": RangeRandomSharpness(2.0, 2.0),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def load_png_to_tensor(path: Path):
|
||||
return torch.from_numpy(np.array(Image.open(path).convert('RGB'))).permute(2, 0, 1)
|
||||
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"transform_key, seed",
|
||||
|
@ -141,27 +133,19 @@ class TestMakeTransforms:
|
|||
("saturation", 1336),
|
||||
("hue", 1336),
|
||||
("sharpness", 1336),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_single_transform(self, transform_key, seed):
|
||||
config = self.config
|
||||
config[transform_key]["weight"] = 1
|
||||
cfg = OmegaConf.create(config)
|
||||
transform = make_image_transforms(cfg, to_dtype=torch.uint8)
|
||||
|
||||
# expected_t = self.transforms[transform_key]
|
||||
with seeded_context(seed):
|
||||
actual = transform(self.original_frame)
|
||||
# torch.manual_seed(42)
|
||||
# actual = actual_t(self.original_frame)
|
||||
# torch.manual_seed(42)
|
||||
# expected = expected_t(self.original_frame)
|
||||
actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
|
||||
with seeded_context(1336):
|
||||
actual = actual_t(self.original_frame)
|
||||
|
||||
# with seeded_context(1336):
|
||||
# expected = expected_t(self.original_frame)
|
||||
expected_t = self.transforms[transform_key]
|
||||
with seeded_context(1336):
|
||||
expected = expected_t(self.original_frame)
|
||||
|
||||
expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png")
|
||||
# # expected = self.expected_frames[transform_key]
|
||||
to_pil = v2.ToPILImage()
|
||||
to_pil(actual).save(self.path / f"{seed}_{transform_key}_test.png", quality=100)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
|
Loading…
Reference in New Issue