81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
|
import pytest
|
||
|
import torch
|
||
|
from torchvision.transforms import v2
|
||
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||
|
|
||
|
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness
|
||
|
|
||
|
|
||
|
class TestRandomSubsetApply:
    """Tests for ``RandomSubsetApply`` using a pool of color jitters and a pool of flips."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build the transform pools and a random 3x224x224 image shared by every test."""
        self.jitters = [
            v2.ColorJitter(brightness=0.5),
            v2.ColorJitter(contrast=0.5),
            v2.ColorJitter(saturation=0.5),
        ]
        # p=1 makes each flip always fire, so flip outputs can be compared exactly.
        self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
        self.img = torch.rand(3, 224, 224)

    @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
    def test_random_choice(self, p):
        """With one-hot weights and ``n_subset=1``, only the weighted flip is applied."""
        picker = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
        result = picker(self.img)

        # p[0] is the horizontal-flip weight; when it is zero, all weight sits on the vertical flip.
        reference = F.horizontal_flip(self.img) if p[0] else F.vertical_flip(self.img)
        torch.testing.assert_close(result, reference)

    def test_transform_all(self):
        """With default ``p`` / ``n_subset``, the output keeps the input image's shape."""
        jittered = RandomSubsetApply(self.jitters)(self.img)
        assert jittered.shape == self.img.shape

    def test_transform_subset(self):
        """Applying a random subset of two jitters keeps the input image's shape."""
        jittered = RandomSubsetApply(self.jitters, n_subset=2)(self.img)
        assert jittered.shape == self.img.shape

    def test_random_order(self):
        """Selecting both flips with ``random_order=True`` matches composing them in list order."""
        shuffled = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
        # The application order itself is not observable here. Horizontal and vertical
        # flips commute, so whatever order is used, the result equals the fixed-order composition.
        observed = shuffled(self.img)
        torch.testing.assert_close(observed, v2.Compose(self.flips)(self.img))

    def test_probability_length_mismatch(self):
        """A ``p`` list whose length differs from the number of transforms raises ``ValueError``."""
        mismatched_p = [0.5, 0.5]
        with pytest.raises(ValueError):
            RandomSubsetApply(self.jitters, p=mismatched_p)

    def test_invalid_n_subset(self):
        """Requesting more transforms than the pool holds raises ``ValueError``."""
        with pytest.raises(ValueError):
            RandomSubsetApply(self.jitters, n_subset=len(self.jitters) + 2)
|
||
|
|
||
|
|
||
|
class TestRangeRandomSharpness:
    """Tests for constructing and applying ``RangeRandomSharpness``."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Create a random 3x224x224 image for each test."""
        self.img = torch.rand(3, 224, 224)

    def test_valid_range(self):
        """A valid (min, max) range yields an output with the input image's shape."""
        sharpened = RangeRandomSharpness(0.1, 2.0)(self.img)
        assert sharpened.shape == self.img.shape

    def test_invalid_range_min_negative(self):
        """A negative lower bound raises ``ValueError``."""
        with pytest.raises(ValueError):
            RangeRandomSharpness(-0.1, 2.0)

    def test_invalid_range_max_smaller(self):
        """An upper bound below the lower bound raises ``ValueError``."""
        lower, upper = 2.0, 0.1
        with pytest.raises(ValueError):
            RangeRandomSharpness(lower, upper)
|
||
|
|
||
|
|
||
|
class TestMakeTransforms:
    """Placeholder test class; no tests are implemented yet."""

    # TODO: add tests (the class name suggests a transforms factory is the intended target).
|