From ee51f54cb51ed037acdee870f659753c822cc96a Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 29 Oct 2024 16:08:01 +0100 Subject: [PATCH] Remove dataset from image_transform tests --- tests/test_image_transforms.py | 83 +++++++++++++++++----------------- 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index ccc40ddf..9f3a62ca 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -23,7 +23,6 @@ from safetensors.torch import load_file from torchvision.transforms import v2 from torchvision.transforms.v2 import functional as F # noqa: N812 -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms from lerobot.common.utils.utils import init_hydra_config, seeded_context from lerobot.scripts.visualize_image_transforms import visualize_transforms @@ -38,14 +37,14 @@ def load_png_to_tensor(path: Path): @pytest.fixture -def img(): - dataset = LeRobotDataset(DATASET_REPO_ID) - return dataset[0][dataset.camera_keys[0]] +def img_tensor() -> torch.Tensor: + return torch.rand((3, 480, 640), dtype=torch.float32) @pytest.fixture -def img_random(): - return torch.rand(3, 480, 640) +def img() -> Image: + img_array = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8) + return Image.fromarray(img_array) @pytest.fixture @@ -67,47 +66,47 @@ def default_transforms(): return load_file(ARTIFACT_DIR / "default_transforms.safetensors") -def test_get_image_transforms_no_transform(img): +def test_get_image_transforms_no_transform(img_tensor): tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0) - torch.testing.assert_close(tf_actual(img), img) + torch.testing.assert_close(tf_actual(img_tensor), img_tensor) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_brightness(img, min_max): +def test_get_image_transforms_brightness(img_tensor, min_max): tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max) tf_expected = v2.ColorJitter(brightness=min_max) - torch.testing.assert_close(tf_actual(img), tf_expected(img)) + torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_contrast(img, min_max): +def test_get_image_transforms_contrast(img_tensor, min_max): tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max) tf_expected = v2.ColorJitter(contrast=min_max) - torch.testing.assert_close(tf_actual(img), tf_expected(img)) + torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_saturation(img, min_max): +def test_get_image_transforms_saturation(img_tensor, min_max): tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max) tf_expected = v2.ColorJitter(saturation=min_max) - torch.testing.assert_close(tf_actual(img), tf_expected(img)) + torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) @pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)]) -def test_get_image_transforms_hue(img, min_max): +def test_get_image_transforms_hue(img_tensor, min_max): tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max) tf_expected = v2.ColorJitter(hue=min_max) - torch.testing.assert_close(tf_actual(img), tf_expected(img)) + torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) -def test_get_image_transforms_sharpness(img, min_max): +def test_get_image_transforms_sharpness(img_tensor, min_max): tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max) tf_expected = SharpnessJitter(sharpness=min_max) - torch.testing.assert_close(tf_actual(img), tf_expected(img)) + torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) -def test_get_image_transforms_max_num_transforms(img): +def test_get_image_transforms_max_num_transforms(img_tensor): tf_actual = get_image_transforms( brightness_min_max=(0.5, 0.5), contrast_min_max=(0.5, 0.5), @@ -125,11 +124,11 @@ def test_get_image_transforms_max_num_transforms(img): SharpnessJitter(sharpness=(0.5, 0.5)), ] ) - torch.testing.assert_close(tf_actual(img), tf_expected(img)) + torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) @require_x86_64_kernel -def test_get_image_transforms_random_order(img): +def test_get_image_transforms_random_order(img_tensor): out_imgs = [] tf = get_image_transforms( brightness_min_max=(0.5, 0.5), @@ -141,7 +140,7 @@ def test_get_image_transforms_random_order(img): ) with seeded_context(1337): for _ in range(10): - out_imgs.append(tf(img)) + out_imgs.append(tf(img_tensor)) for i in range(1, len(out_imgs)): with pytest.raises(AssertionError): @@ -158,21 +157,21 @@ def test_get_image_transforms_random_order(img): ("sharpness", [(0.5, 0.5), (2.0, 2.0)]), ], ) -def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms): +def test_backward_compatibility_torchvision(transform, min_max_values, img_tensor, single_transforms): for min_max in min_max_values: kwargs = { f"{transform}_weight": 1.0, f"{transform}_min_max": min_max, } tf = get_image_transforms(**kwargs) - actual = tf(img) + actual = tf(img_tensor) key = f"{transform}_{min_max[0]}_{min_max[1]}" expected = single_transforms[key] torch.testing.assert_close(actual, expected) @require_x86_64_kernel -def test_backward_compatibility_default_config(img, default_transforms): +def test_backward_compatibility_default_config(img_tensor, default_transforms): cfg = init_hydra_config(DEFAULT_CONFIG_PATH) cfg_tf = cfg.training.image_transforms default_tf = get_image_transforms( @@ -191,7 +190,7 @@ def test_backward_compatibility_default_config(img, default_transforms): ) with seeded_context(1337): - actual = default_tf(img) + actual = default_tf(img_tensor) expected = default_transforms["default"] @@ -199,33 +198,33 @@ def test_backward_compatibility_default_config(img, default_transforms): @pytest.mark.parametrize("p", [[0, 1], [1, 0]]) -def test_random_subset_apply_single_choice(p, img): +def test_random_subset_apply_single_choice(p, img_tensor): flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False) - actual = random_choice(img) + actual = random_choice(img_tensor) p_horz, _ = p if p_horz: - torch.testing.assert_close(actual, F.horizontal_flip(img)) + torch.testing.assert_close(actual, F.horizontal_flip(img_tensor)) else: - torch.testing.assert_close(actual, F.vertical_flip(img)) + torch.testing.assert_close(actual, F.vertical_flip(img_tensor)) -def test_random_subset_apply_random_order(img): +def test_random_subset_apply_random_order(img_tensor): flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True) # We can't really check whether the transforms are actually applied in random order. However, # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform # applies them in random order, we can use a fixed order to compute the expected value. - actual = random_order(img) - expected = v2.Compose(flips)(img) + actual = random_order(img_tensor) + expected = v2.Compose(flips)(img_tensor) torch.testing.assert_close(actual, expected) -def test_random_subset_apply_valid_transforms(color_jitters, img): +def test_random_subset_apply_valid_transforms(color_jitters, img_tensor): transform = RandomSubsetApply(color_jitters) - output = transform(img) - assert output.shape == img.shape + output = transform(img_tensor) + assert output.shape == img_tensor.shape def test_random_subset_apply_probability_length_mismatch(color_jitters): @@ -239,16 +238,16 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset): RandomSubsetApply(color_jitters, n_subset=n_subset) -def test_sharpness_jitter_valid_range_tuple(img): +def test_sharpness_jitter_valid_range_tuple(img_tensor): tf = SharpnessJitter((0.1, 2.0)) - output = tf(img) - assert output.shape == img.shape + output = tf(img_tensor) + assert output.shape == img_tensor.shape -def test_sharpness_jitter_valid_range_float(img): +def test_sharpness_jitter_valid_range_float(img_tensor): tf = SharpnessJitter(0.5) - output = tf(img) - assert output.shape == img.shape + output = tf(img_tensor) + assert output.shape == img_tensor.shape def test_sharpness_jitter_invalid_range_min_negative():