Add img and img_tensor factories

Simon Alibert 2024-11-02 13:06:38 +01:00
parent 293bdc7f67
commit 375abd3020
2 changed files with 52 additions and 32 deletions
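The new fixtures follow pytest's factory-as-fixture pattern: a session-scoped fixture returns an inner function, so a test can build as many images as it needs, at whatever size it needs, from fixtures shared across the test suite instead of ones defined locally in the image-transforms test file. A minimal sketch of the pattern and of how a test consumes it (the consuming test below is illustrative and not part of this commit):

import pytest
import torch


@pytest.fixture(scope="session")
def img_tensor_factory():
    # The fixture returns a callable rather than a tensor, so every call yields a fresh random image.
    def _create_img_tensor(width=100, height=100) -> torch.Tensor:
        return torch.rand((3, height, width), dtype=torch.float32)

    return _create_img_tensor


# Hypothetical consumer (not in the diff): request the factory, then call it with the desired size.
def test_example_shape(img_tensor_factory):
    img_tensor = img_tensor_factory(width=640, height=480)
    assert img_tensor.shape == (3, 480, 640)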


@@ -4,7 +4,9 @@ from unittest.mock import patch
 import datasets
 import numpy as np
+import PIL.Image
 import pytest
+import torch
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.utils import (
@@ -37,6 +39,14 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
     return task_to_task_index[task]
+@pytest.fixture(scope="session")
+def img_tensor_factory():
+    def _create_img_tensor(width=100, height=100) -> torch.Tensor:
+        return torch.rand((3, height, width), dtype=torch.float32)
+    return _create_img_tensor
 @pytest.fixture(scope="session")
 def img_array_factory():
     def _create_img_array(width=100, height=100) -> np.ndarray:
@@ -45,6 +55,15 @@ def img_array_factory():
     return _create_img_array
+@pytest.fixture(scope="session")
+def img_factory(img_array_factory):
+    def _create_img(width=100, height=100) -> PIL.Image.Image:
+        img_array = img_array_factory(width=width, height=height)
+        return PIL.Image.fromarray(img_array)
+    return _create_img
 @pytest.fixture(scope="session")
 def info_factory():
     def _create_info(

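In the transforms test file below, every test that previously took the fixed img_tensor fixture now takes img_tensor_factory and materializes the tensor on its first line; the old module-level img and img_tensor fixtures are removed in favor of the shared factories. An illustrative test written in the migrated style (the test name and sizes are mine; only the img_tensor_factory fixture and the torchvision v2 import come from the diff):

from torchvision.transforms import v2


# Hypothetical test (not in the commit) showing the migrated pattern:
# take the factory fixture, build the image, then exercise a transform on it.
def test_horizontal_flip_preserves_shape(img_tensor_factory):
    img_tensor = img_tensor_factory(width=640, height=480)
    out = v2.RandomHorizontalFlip(p=1.0)(img_tensor)
    assert out.shape == img_tensor.shape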

@@ -15,10 +15,8 @@
 # limitations under the License.
 from pathlib import Path
-import numpy as np
 import pytest
 import torch
-from PIL import Image
 from safetensors.torch import load_file
 from torchvision.transforms import v2
 from torchvision.transforms.v2 import functional as F  # noqa: N812
@@ -32,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
 DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
-def load_png_to_tensor(path: Path):
-    return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
-@pytest.fixture
-def img_tensor() -> torch.Tensor:
-    return torch.rand((3, 480, 640), dtype=torch.float32)
-@pytest.fixture
-def img() -> Image:
-    img_array = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
-    return Image.fromarray(img_array)
 @pytest.fixture
 def color_jitters():
     return [
@@ -66,47 +49,54 @@ def default_transforms():
     return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
-def test_get_image_transforms_no_transform(img_tensor):
+def test_get_image_transforms_no_transform(img_tensor_factory):
+    img_tensor = img_tensor_factory()
     tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
     torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
 @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
-def test_get_image_transforms_brightness(img_tensor, min_max):
+def test_get_image_transforms_brightness(img_tensor_factory, min_max):
+    img_tensor = img_tensor_factory()
     tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
     tf_expected = v2.ColorJitter(brightness=min_max)
     torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
 @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
-def test_get_image_transforms_contrast(img_tensor, min_max):
+def test_get_image_transforms_contrast(img_tensor_factory, min_max):
+    img_tensor = img_tensor_factory()
     tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
     tf_expected = v2.ColorJitter(contrast=min_max)
     torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
 @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
-def test_get_image_transforms_saturation(img_tensor, min_max):
+def test_get_image_transforms_saturation(img_tensor_factory, min_max):
+    img_tensor = img_tensor_factory()
     tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
     tf_expected = v2.ColorJitter(saturation=min_max)
     torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
 @pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
-def test_get_image_transforms_hue(img_tensor, min_max):
+def test_get_image_transforms_hue(img_tensor_factory, min_max):
+    img_tensor = img_tensor_factory()
     tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
     tf_expected = v2.ColorJitter(hue=min_max)
     torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
 @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
-def test_get_image_transforms_sharpness(img_tensor, min_max):
+def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
+    img_tensor = img_tensor_factory()
     tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
     tf_expected = SharpnessJitter(sharpness=min_max)
     torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
-def test_get_image_transforms_max_num_transforms(img_tensor):
+def test_get_image_transforms_max_num_transforms(img_tensor_factory):
+    img_tensor = img_tensor_factory()
     tf_actual = get_image_transforms(
         brightness_min_max=(0.5, 0.5),
         contrast_min_max=(0.5, 0.5),
@@ -128,8 +118,9 @@ def test_get_image_transforms_max_num_transforms(img_tensor):
 @require_x86_64_kernel
-def test_get_image_transforms_random_order(img_tensor):
+def test_get_image_transforms_random_order(img_tensor_factory):
     out_imgs = []
+    img_tensor = img_tensor_factory()
     tf = get_image_transforms(
         brightness_min_max=(0.5, 0.5),
         contrast_min_max=(0.5, 0.5),
@@ -147,6 +138,7 @@ def test_get_image_transforms_random_order(img_tensor):
         torch.testing.assert_close(out_imgs[0], out_imgs[i])
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
 @pytest.mark.parametrize(
     "transform, min_max_values",
     [
@@ -157,7 +149,8 @@ def test_get_image_transforms_random_order(img_tensor):
         ("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
     ],
 )
-def test_backward_compatibility_torchvision(transform, min_max_values, img_tensor, single_transforms):
+def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
+    img_tensor = img_tensor_factory()
     for min_max in min_max_values:
         kwargs = {
             f"{transform}_weight": 1.0,
@@ -170,8 +163,10 @@ def test_backward_compatibility_torchvision(transform, min_max_values, img_tensor, single_transforms):
         torch.testing.assert_close(actual, expected)
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
 @require_x86_64_kernel
-def test_backward_compatibility_default_config(img_tensor, default_transforms):
+def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
+    img_tensor = img_tensor_factory()
     cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
     cfg_tf = cfg.training.image_transforms
     default_tf = get_image_transforms(
@@ -198,7 +193,8 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
 @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
-def test_random_subset_apply_single_choice(p, img_tensor):
+def test_random_subset_apply_single_choice(img_tensor_factory, p):
+    img_tensor = img_tensor_factory()
     flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
     random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
     actual = random_choice(img_tensor)
@@ -210,7 +206,8 @@ def test_random_subset_apply_single_choice(p, img_tensor):
         torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
-def test_random_subset_apply_random_order(img_tensor):
+def test_random_subset_apply_random_order(img_tensor_factory):
+    img_tensor = img_tensor_factory()
     flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
     random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
     # We can't really check whether the transforms are actually applied in random order. However,
@@ -221,7 +218,8 @@ def test_random_subset_apply_random_order(img_tensor):
     torch.testing.assert_close(actual, expected)
-def test_random_subset_apply_valid_transforms(color_jitters, img_tensor):
+def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
+    img_tensor = img_tensor_factory()
     transform = RandomSubsetApply(color_jitters)
     output = transform(img_tensor)
     assert output.shape == img_tensor.shape
@@ -238,13 +236,15 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
         RandomSubsetApply(color_jitters, n_subset=n_subset)
-def test_sharpness_jitter_valid_range_tuple(img_tensor):
+def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
+    img_tensor = img_tensor_factory()
     tf = SharpnessJitter((0.1, 2.0))
     output = tf(img_tensor)
     assert output.shape == img_tensor.shape
-def test_sharpness_jitter_valid_range_float(img_tensor):
+def test_sharpness_jitter_valid_range_float(img_tensor_factory):
+    img_tensor = img_tensor_factory()
     tf = SharpnessJitter(0.5)
     output = tf(img_tensor)
     assert output.shape == img_tensor.shape
@@ -260,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
         SharpnessJitter((2.0, 0.1))
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
 @pytest.mark.parametrize(
     "repo_id, n_examples",
     [