from pathlib import Path

import numpy as np
import pytest
import torch
from omegaconf import OmegaConf
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F  # noqa: N812

from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH
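
# This module tests the image-augmentation transforms used on dataset frames:
# `RandomSubsetApply`, `RangeRandomSharpness`, and the `make_image_transforms` factory,
# which is checked against reference frames stored in tests/data/save_image_transforms.
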

class TestRandomSubsetApply:
    @pytest.fixture(autouse=True)
    def setup(self):
        self.jitters = [
            v2.ColorJitter(brightness=0.5),
            v2.ColorJitter(contrast=0.5),
            v2.ColorJitter(saturation=0.5),
        ]
        self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
        self.img = torch.rand(3, 224, 224)

    @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
    def test_random_choice(self, p):
        random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
        output = random_choice(self.img)

        p_horz, _ = p
        if p_horz:
            torch.testing.assert_close(output, F.horizontal_flip(self.img))
        else:
            torch.testing.assert_close(output, F.vertical_flip(self.img))

    def test_transform_all(self):
        transform = RandomSubsetApply(self.jitters)
        output = transform(self.img)
        assert output.shape == self.img.shape

    def test_transform_subset(self):
        transform = RandomSubsetApply(self.jitters, n_subset=2)
        output = transform(self.img)
        assert output.shape == self.img.shape

    def test_random_order(self):
        random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
        # We can't directly check whether the transforms are applied in random order.
        # However, horizontal and vertical flips are commutative, so even if the
        # transform applies them in a random order, a fixed order yields the same
        # expected value.
        actual = random_order(self.img)
        expected = v2.Compose(self.flips)(self.img)
        torch.testing.assert_close(actual, expected)

    def test_probability_length_mismatch(self):
        with pytest.raises(ValueError):
            RandomSubsetApply(self.jitters, p=[0.5, 0.5])

    def test_invalid_n_subset(self):
        with pytest.raises(ValueError):
            RandomSubsetApply(self.jitters, n_subset=5)
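

# For context: a minimal sketch of the contract the tests above assume from
# `RandomSubsetApply` (a hypothetical reimplementation, not the class under test).
# It draws `n_subset` transforms without replacement, weighted by `p`, and applies
# them sequentially; the drawn order is kept only when `random_order=True`.
def _reference_random_subset_apply(transforms, img, p=None, n_subset=None, random_order=False):
    n_subset = len(transforms) if n_subset is None else n_subset
    weights = torch.ones(len(transforms)) if p is None else torch.tensor(p, dtype=torch.float)
    # Sample which transforms to apply, without replacement.
    indices = torch.multinomial(weights, n_subset)
    if not random_order:
        # Restore the original list ordering of the selected transforms.
        indices = indices.sort().values
    for i in indices:
        img = transforms[i](img)
    return img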


class TestRangeRandomSharpness:
    @pytest.fixture(autouse=True)
    def setup(self):
        self.img = torch.rand(3, 224, 224)

    def test_valid_range(self):
        transform = RangeRandomSharpness(0.1, 2.0)
        output = transform(self.img)
        assert output.shape == self.img.shape

    def test_invalid_range_min_negative(self):
        with pytest.raises(ValueError):
            RangeRandomSharpness(-0.1, 2.0)

    def test_invalid_range_max_smaller(self):
        with pytest.raises(ValueError):
            RangeRandomSharpness(2.0, 0.1)
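

# For context: the behavior the tests above assume from `RangeRandomSharpness`,
# sketched as a plain function (hypothetical, not the class under test): reject
# ranges with a negative minimum or max < min, then sharpen with a factor drawn
# uniformly from [range_min, range_max].
def _reference_range_random_sharpness(img, range_min, range_max):
    if range_min < 0 or range_max < range_min:
        raise ValueError(f"Invalid sharpness range: ({range_min}, {range_max})")
    factor = torch.empty(1).uniform_(range_min, range_max).item()
    return F.adjust_sharpness(img, factor)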


class TestMakeTransforms:
    @pytest.fixture(autouse=True)
    def setup(self):
        """The seed must match the one used to generate the reference artifacts."""
        self.config = {
            "enable": True,
            "max_num_transforms": 1,
            "random_order": False,
            "brightness": {
                "weight": 0,
                "min": 0.0,
                "max": 2.0,
            },
            "contrast": {
                "weight": 0,
                "min": 0.0,
                "max": 2.0,
            },
            "saturation": {
                "weight": 0,
                "min": 0.0,
                "max": 2.0,
            },
            "hue": {
                "weight": 0,
                "min": -0.5,
                "max": 0.5,
            },
            "sharpness": {
                "weight": 0,
                "min": 0.0,
                "max": 2.0,
            },
        }
        self.path = Path("tests/data/save_image_transforms")
        self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
        # Each entry mirrors the transform that `make_image_transforms` is expected to
        # build from the matching config entry when its weight is set to 1.
        self.transforms = {
            "brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
            "contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
            "saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
            "hue": v2.ColorJitter(hue=(-0.5, 0.5)),
            "sharpness": RangeRandomSharpness(0.0, 2.0),
        }

    @staticmethod
    def load_png_to_tensor(path: Path):
        """Load a PNG as a uint8 (C, H, W) tensor, matching the pipeline's uint8 output."""
        return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)

    @pytest.mark.parametrize(
        "transform_key, seed",
        [
            ("brightness", 1336),
            ("contrast", 1336),
            ("saturation", 1336),
            ("hue", 1336),
            ("sharpness", 1336),
        ],
    )
    def test_single_transform(self, transform_key, seed):
        # The autouse fixture rebuilds `self.config` for each test, so mutating it
        # in place is safe here.
        config = self.config
        config[transform_key]["weight"] = 1
        cfg = OmegaConf.create(config)
        transform = make_image_transforms(cfg, to_dtype=torch.uint8)

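        # `seeded_context` (from lerobot.common.utils.utils) is assumed to seed the RNGs
        # on entry and restore their previous state on exit, so the transform draws the
        # same random parameters that produced the reference PNG for this seed.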
        with seeded_context(seed):
            actual = transform(self.original_frame)

        expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png")
        torch.testing.assert_close(actual, expected)