Remove dataset from image_transform tests
This commit is contained in:
parent
fee5fa5c2e
commit
ee51f54cb5
|
@ -23,7 +23,6 @@ from safetensors.torch import load_file
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||||
|
@ -38,14 +37,14 @@ def load_png_to_tensor(path: Path):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def img():
|
def img_tensor() -> torch.Tensor:
|
||||||
dataset = LeRobotDataset(DATASET_REPO_ID)
|
return torch.rand((3, 480, 640), dtype=torch.float32)
|
||||||
return dataset[0][dataset.camera_keys[0]]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def img_random():
|
def img() -> Image:
|
||||||
return torch.rand(3, 480, 640)
|
img_array = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
|
||||||
|
return Image.fromarray(img_array)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -67,47 +66,47 @@ def default_transforms():
|
||||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||||
|
|
||||||
|
|
||||||
def test_get_image_transforms_no_transform(img):
|
def test_get_image_transforms_no_transform(img_tensor):
|
||||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||||
torch.testing.assert_close(tf_actual(img), img)
|
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||||
def test_get_image_transforms_brightness(img, min_max):
|
def test_get_image_transforms_brightness(img_tensor, min_max):
|
||||||
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
|
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
|
||||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
tf_expected = v2.ColorJitter(brightness=min_max)
|
||||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||||
def test_get_image_transforms_contrast(img, min_max):
|
def test_get_image_transforms_contrast(img_tensor, min_max):
|
||||||
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
|
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
|
||||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
tf_expected = v2.ColorJitter(contrast=min_max)
|
||||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||||
def test_get_image_transforms_saturation(img, min_max):
|
def test_get_image_transforms_saturation(img_tensor, min_max):
|
||||||
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
|
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
|
||||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
tf_expected = v2.ColorJitter(saturation=min_max)
|
||||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
||||||
def test_get_image_transforms_hue(img, min_max):
|
def test_get_image_transforms_hue(img_tensor, min_max):
|
||||||
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
|
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
|
||||||
tf_expected = v2.ColorJitter(hue=min_max)
|
tf_expected = v2.ColorJitter(hue=min_max)
|
||||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||||
def test_get_image_transforms_sharpness(img, min_max):
|
def test_get_image_transforms_sharpness(img_tensor, min_max):
|
||||||
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
|
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
|
||||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
tf_expected = SharpnessJitter(sharpness=min_max)
|
||||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
def test_get_image_transforms_max_num_transforms(img):
|
def test_get_image_transforms_max_num_transforms(img_tensor):
|
||||||
tf_actual = get_image_transforms(
|
tf_actual = get_image_transforms(
|
||||||
brightness_min_max=(0.5, 0.5),
|
brightness_min_max=(0.5, 0.5),
|
||||||
contrast_min_max=(0.5, 0.5),
|
contrast_min_max=(0.5, 0.5),
|
||||||
|
@ -125,11 +124,11 @@ def test_get_image_transforms_max_num_transforms(img):
|
||||||
SharpnessJitter(sharpness=(0.5, 0.5)),
|
SharpnessJitter(sharpness=(0.5, 0.5)),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
@require_x86_64_kernel
|
@require_x86_64_kernel
|
||||||
def test_get_image_transforms_random_order(img):
|
def test_get_image_transforms_random_order(img_tensor):
|
||||||
out_imgs = []
|
out_imgs = []
|
||||||
tf = get_image_transforms(
|
tf = get_image_transforms(
|
||||||
brightness_min_max=(0.5, 0.5),
|
brightness_min_max=(0.5, 0.5),
|
||||||
|
@ -141,7 +140,7 @@ def test_get_image_transforms_random_order(img):
|
||||||
)
|
)
|
||||||
with seeded_context(1337):
|
with seeded_context(1337):
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
out_imgs.append(tf(img))
|
out_imgs.append(tf(img_tensor))
|
||||||
|
|
||||||
for i in range(1, len(out_imgs)):
|
for i in range(1, len(out_imgs)):
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
@ -158,21 +157,21 @@ def test_get_image_transforms_random_order(img):
|
||||||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
|
def test_backward_compatibility_torchvision(transform, min_max_values, img_tensor, single_transforms):
|
||||||
for min_max in min_max_values:
|
for min_max in min_max_values:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
f"{transform}_weight": 1.0,
|
f"{transform}_weight": 1.0,
|
||||||
f"{transform}_min_max": min_max,
|
f"{transform}_min_max": min_max,
|
||||||
}
|
}
|
||||||
tf = get_image_transforms(**kwargs)
|
tf = get_image_transforms(**kwargs)
|
||||||
actual = tf(img)
|
actual = tf(img_tensor)
|
||||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||||
expected = single_transforms[key]
|
expected = single_transforms[key]
|
||||||
torch.testing.assert_close(actual, expected)
|
torch.testing.assert_close(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
@require_x86_64_kernel
|
@require_x86_64_kernel
|
||||||
def test_backward_compatibility_default_config(img, default_transforms):
|
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||||
cfg_tf = cfg.training.image_transforms
|
cfg_tf = cfg.training.image_transforms
|
||||||
default_tf = get_image_transforms(
|
default_tf = get_image_transforms(
|
||||||
|
@ -191,7 +190,7 @@ def test_backward_compatibility_default_config(img, default_transforms):
|
||||||
)
|
)
|
||||||
|
|
||||||
with seeded_context(1337):
|
with seeded_context(1337):
|
||||||
actual = default_tf(img)
|
actual = default_tf(img_tensor)
|
||||||
|
|
||||||
expected = default_transforms["default"]
|
expected = default_transforms["default"]
|
||||||
|
|
||||||
|
@ -199,33 +198,33 @@ def test_backward_compatibility_default_config(img, default_transforms):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
||||||
def test_random_subset_apply_single_choice(p, img):
|
def test_random_subset_apply_single_choice(p, img_tensor):
|
||||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||||
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
|
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
|
||||||
actual = random_choice(img)
|
actual = random_choice(img_tensor)
|
||||||
|
|
||||||
p_horz, _ = p
|
p_horz, _ = p
|
||||||
if p_horz:
|
if p_horz:
|
||||||
torch.testing.assert_close(actual, F.horizontal_flip(img))
|
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
|
||||||
else:
|
else:
|
||||||
torch.testing.assert_close(actual, F.vertical_flip(img))
|
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
def test_random_subset_apply_random_order(img):
|
def test_random_subset_apply_random_order(img_tensor):
|
||||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||||
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
||||||
# We can't really check whether the transforms are actually applied in random order. However,
|
# We can't really check whether the transforms are actually applied in random order. However,
|
||||||
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
|
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
|
||||||
# applies them in random order, we can use a fixed order to compute the expected value.
|
# applies them in random order, we can use a fixed order to compute the expected value.
|
||||||
actual = random_order(img)
|
actual = random_order(img_tensor)
|
||||||
expected = v2.Compose(flips)(img)
|
expected = v2.Compose(flips)(img_tensor)
|
||||||
torch.testing.assert_close(actual, expected)
|
torch.testing.assert_close(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_random_subset_apply_valid_transforms(color_jitters, img):
|
def test_random_subset_apply_valid_transforms(color_jitters, img_tensor):
|
||||||
transform = RandomSubsetApply(color_jitters)
|
transform = RandomSubsetApply(color_jitters)
|
||||||
output = transform(img)
|
output = transform(img_tensor)
|
||||||
assert output.shape == img.shape
|
assert output.shape == img_tensor.shape
|
||||||
|
|
||||||
|
|
||||||
def test_random_subset_apply_probability_length_mismatch(color_jitters):
|
def test_random_subset_apply_probability_length_mismatch(color_jitters):
|
||||||
|
@ -239,16 +238,16 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
|
||||||
RandomSubsetApply(color_jitters, n_subset=n_subset)
|
RandomSubsetApply(color_jitters, n_subset=n_subset)
|
||||||
|
|
||||||
|
|
||||||
def test_sharpness_jitter_valid_range_tuple(img):
|
def test_sharpness_jitter_valid_range_tuple(img_tensor):
|
||||||
tf = SharpnessJitter((0.1, 2.0))
|
tf = SharpnessJitter((0.1, 2.0))
|
||||||
output = tf(img)
|
output = tf(img_tensor)
|
||||||
assert output.shape == img.shape
|
assert output.shape == img_tensor.shape
|
||||||
|
|
||||||
|
|
||||||
def test_sharpness_jitter_valid_range_float(img):
|
def test_sharpness_jitter_valid_range_float(img_tensor):
|
||||||
tf = SharpnessJitter(0.5)
|
tf = SharpnessJitter(0.5)
|
||||||
output = tf(img)
|
output = tf(img_tensor)
|
||||||
assert output.shape == img.shape
|
assert output.shape == img_tensor.shape
|
||||||
|
|
||||||
|
|
||||||
def test_sharpness_jitter_invalid_range_min_negative():
|
def test_sharpness_jitter_invalid_range_min_negative():
|
||||||
|
|
Loading…
Reference in New Issue