# File: lerobot/tests/test_transforms.py (254 lines, 8.9 KiB, Python)

from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import pytest
import torch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, get_image_transforms
from lerobot.common.utils.utils import seeded_context
# TODO:
# - add test_make_image_transforms
# - test backward compatibility with torchvision (save reference artifacts)
# - test backward compatibility with the default yaml config, both with
#   `enable: false` and `enable: true` (save reference artifacts)
def test_get_image_transforms_no_transform():
    """Building the transform pipeline in degenerate configurations must not raise.

    Covers the all-defaults call, a zero sharpness weight, and a zero cap on the
    number of applied transforms. Only construction is exercised; the returned
    pipeline is not inspected.
    """
    for overrides in ({}, {"sharpness_weight": 0.0}, {"max_num_transforms": 0}):
        get_image_transforms(**overrides)
@pytest.fixture
def img():
    """Load the reference frame used by the transform tests as a uint8 CHW tensor.

    The path is relative to the current working directory, so the tests are
    expected to be run from the repository root.
    """
    path = "tests/data/save_image_transforms/original_frame.png"
    # Close the file deterministically instead of relying on garbage collection.
    with Image.open(path) as pil_img:
        hwc = np.array(pil_img.convert("RGB"))
    # numpy/PIL give HWC; torchvision transforms expect CHW.
    return torch.from_numpy(hwc).permute(2, 0, 1)
def test_get_image_transforms_brightness(img):
    """A brightness-only pipeline must match torchvision's ColorJitter(brightness=...).

    The range is degenerate (min == max), so the sampled factor is fixed and the
    two outputs are directly comparable.
    """
    factor_range = (0.5, 0.5)
    pipeline = get_image_transforms(brightness_weight=1.0, brightness_min_max=factor_range)
    reference = v2.ColorJitter(brightness=factor_range)
    actual, expected = pipeline(img), reference(img)
    torch.testing.assert_close(actual, expected)
def test_get_image_transforms_contrast(img):
    """A contrast-only pipeline must reproduce torchvision's ColorJitter(contrast=...).

    min == max, so the sampled contrast factor is fixed and the outputs can be
    compared directly.
    """
    lo_hi = (0.5, 0.5)
    built = get_image_transforms(contrast_weight=1.0, contrast_min_max=lo_hi)
    baseline = v2.ColorJitter(contrast=lo_hi)
    got = built(img)
    want = baseline(img)
    torch.testing.assert_close(got, want)
def test_get_image_transforms_saturation(img):
    """A saturation-only pipeline must agree with torchvision's ColorJitter(saturation=...).

    Using a single-point range removes randomness from the comparison.
    """
    fixed = (0.5, 0.5)
    transforms_under_test = get_image_transforms(saturation_weight=1.0, saturation_min_max=fixed)
    jitter = v2.ColorJitter(saturation=fixed)
    torch.testing.assert_close(
        transforms_under_test(img),
        jitter(img),
    )
def test_get_image_transforms_hue(img):
    """A hue-only pipeline must agree with torchvision's ColorJitter(hue=...).

    The hue shift range collapses to a single value, so both sides apply the
    same shift.
    """
    shift = (0.5, 0.5)
    under_test = get_image_transforms(hue_weight=1.0, hue_min_max=shift)
    oracle = v2.ColorJitter(hue=shift)
    observed = under_test(img)
    torch.testing.assert_close(observed, oracle(img))
def test_get_image_transforms_sharpness(img):
    """A sharpness-only pipeline must match a standalone RangeRandomSharpness.

    The range is degenerate (min == max), so the sharpness factor is fixed and
    the outputs can be compared directly.
    """
    sharpness_min_max = (0.5, 0.5)
    tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=sharpness_min_max)
    # RangeRandomSharpness takes (min, max) positionally. Unpacking the tuple
    # with `**` would raise TypeError, because `**` requires a mapping.
    tf_expected = RangeRandomSharpness(*sharpness_min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_max_num_transforms(img):
    """With every transform enabled and the cap equal to the number of transforms,
    all five must be applied, in a fixed order when ``random_order=False``.

    Every range is degenerate (min == max), so the result is deterministic and
    can be compared against an explicit ``v2.Compose`` of the same transforms.
    """
    tf_actual = get_image_transforms(
        brightness_weight=1.0,
        brightness_min_max=(0.5, 0.5),
        contrast_weight=1.0,
        contrast_min_max=(0.5, 0.5),
        saturation_weight=1.0,
        saturation_min_max=(0.5, 0.5),
        hue_weight=1.0,
        hue_min_max=(0.5, 0.5),
        sharpness_weight=1.0,
        sharpness_min_max=(0.5, 0.5),
        max_num_transforms=5,
        random_order=False,
    )
    # NOTE(review): assumes get_image_transforms applies the enabled transforms
    # in the order brightness, contrast, saturation, hue, sharpness — confirm.
    tf_expected = v2.Compose(
        [
            v2.ColorJitter(brightness=(0.5, 0.5)),
            v2.ColorJitter(contrast=(0.5, 0.5)),
            v2.ColorJitter(saturation=(0.5, 0.5)),
            v2.ColorJitter(hue=(0.5, 0.5)),
            RangeRandomSharpness(0.5, 0.5),
        ]
    )
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_random_order(img):
    """With ``random_order=True``, repeated application must not always give the same image.

    All five transforms are enabled with fixed (min == max) parameters, so the
    only source of variation between applications is the order in which the
    transforms run. The check requires at least one output to differ from the
    first one, rather than all of them. A random permutation can legitimately
    repeat, so "all differ" would be fragile.
    """
    with seeded_context(1337):
        tf = get_image_transforms(
            brightness_weight=1.0,
            brightness_min_max=(0.5, 0.5),
            contrast_weight=1.0,
            contrast_min_max=(0.5, 0.5),
            saturation_weight=1.0,
            saturation_min_max=(0.5, 0.5),
            hue_weight=1.0,
            hue_min_max=(0.5, 0.5),
            sharpness_weight=1.0,
            sharpness_min_max=(0.5, 0.5),
            max_num_transforms=5,
            random_order=True,
        )
        out_imgs = [tf(img) for _ in range(20)]

    first = out_imgs[0]
    # torch.testing.assert_close raises AssertionError (not ValueError), so the
    # comparison is done with torch.equal instead of pytest.raises.
    assert any(not torch.equal(first, other) for other in out_imgs[1:]), (
        "random_order=True produced identical outputs for every application"
    )
@pytest.mark.skip(reason="TODO: backward-compatibility test against saved torchvision artifacts is not implemented yet")
def test_backward_compatibility_torchvision():
    """Placeholder: compare transform outputs against saved torchvision artifacts.

    Marked as skipped so the missing test shows up in the pytest report instead
    of silently passing.
    """
@pytest.mark.skip(reason="TODO: backward-compatibility test for the default yaml config is not implemented yet")
def test_backward_compatibility_default_yaml():
    """Placeholder: check the default yaml config (enable false / true) against saved artifacts.

    Marked as skipped so the missing test shows up in the pytest report instead
    of silently passing.
    """
# class TestRandomSubsetApply:
# @pytest.fixture(autouse=True)
# def setup(self):
# self.jitters = [
# v2.ColorJitter(brightness=0.5),
# v2.ColorJitter(contrast=0.5),
# v2.ColorJitter(saturation=0.5),
# ]
# self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
# self.img = torch.rand(3, 224, 224)
# @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
# def test_random_choice(self, p):
# random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
# output = random_choice(self.img)
# p_horz, _ = p
# if p_horz:
# torch.testing.assert_close(output, F.horizontal_flip(self.img))
# else:
# torch.testing.assert_close(output, F.vertical_flip(self.img))
# def test_transform_all(self):
# transform = RandomSubsetApply(self.jitters)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_transform_subset(self):
# transform = RandomSubsetApply(self.jitters, n_subset=2)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_random_order(self):
# random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# # We can't really check whether the transforms are actually applied in random order. However,
# # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# # applies them in random order, we can use a fixed order to compute the expected value.
# actual = random_order(self.img)
# expected = v2.Compose(self.flips)(self.img)
# torch.testing.assert_close(actual, expected)
# def test_probability_length_mismatch(self):
# with pytest.raises(ValueError):
# RandomSubsetApply(self.jitters, p=[0.5, 0.5])
# def test_invalid_n_subset(self):
# with pytest.raises(ValueError):
# RandomSubsetApply(self.jitters, n_subset=5)
# class TestRangeRandomSharpness:
# @pytest.fixture(autouse=True)
# def setup(self):
# self.img = torch.rand(3, 224, 224)
# def test_valid_range(self):
# transform = RangeRandomSharpness(0.1, 2.0)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_invalid_range_min_negative(self):
# with pytest.raises(ValueError):
# RangeRandomSharpness(-0.1, 2.0)
# def test_invalid_range_max_smaller(self):
# with pytest.raises(ValueError):
# RangeRandomSharpness(2.0, 0.1)
# class TestMakeImageTransforms:
# @pytest.fixture(autouse=True)
# def setup(self):
# """Seed should be the same as the one that was used to generate artifacts"""
# self.config = {
# "enable": True,
# "max_num_transforms": 1,
# "random_order": False,
# "brightness": {"weight": 0, "min": 2.0, "max": 2.0},
# "contrast": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# "saturation": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# "hue": {
# "weight": 0,
# "min": 0.5,
# "max": 0.5,
# },
# "sharpness": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# }
# self.path = Path("tests/data/save_image_transforms")
# self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
# self.transforms = {
# "brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
# "contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
# "saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
# "hue": v2.ColorJitter(hue=(0.5, 0.5)),
# "sharpness": RangeRandomSharpness(2.0, 2.0),
# }
# @staticmethod
# def load_png_to_tensor(path: Path):
# return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
# @pytest.mark.parametrize(
# "transform_key, seed",
# [
# ("brightness", 1336),
# ("contrast", 1336),
# ("saturation", 1336),
# ("hue", 1336),
# ("sharpness", 1336),
# ],
# )
# def test_single_transform(self, transform_key, seed):
# config = self.config
# config[transform_key]["weight"] = 1
# cfg = OmegaConf.create(config)
# actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
# with seeded_context(1336):
# actual = actual_t(self.original_frame)
# expected_t = self.transforms[transform_key]
# with seeded_context(1336):
# expected = expected_t(self.original_frame)
# torch.testing.assert_close(actual, expected)