Add img and img_tensor factories

This commit is contained in:
Simon Alibert 2024-11-02 13:06:38 +01:00
parent 293bdc7f67
commit 375abd3020
2 changed files with 52 additions and 32 deletions

View File

@ -4,7 +4,9 @@ from unittest.mock import patch
import datasets import datasets
import numpy as np import numpy as np
import PIL.Image
import pytest import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
@ -37,6 +39,14 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
return task_to_task_index[task] return task_to_task_index[task]
@pytest.fixture(scope="session")
def img_tensor_factory():
    """Session-scoped fixture returning a builder for random image tensors.

    The returned callable accepts ``width`` and ``height`` (both default to
    100) and produces a float32 tensor of shape ``(3, height, width)`` filled
    by ``torch.rand``.
    """

    def _create_img_tensor(width=100, height=100) -> torch.Tensor:
        # Channel-first layout: 3 channels, then rows (height), then columns (width).
        shape = (3, height, width)
        return torch.rand(shape, dtype=torch.float32)

    return _create_img_tensor
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def img_array_factory(): def img_array_factory():
def _create_img_array(width=100, height=100) -> np.ndarray: def _create_img_array(width=100, height=100) -> np.ndarray:
@ -45,6 +55,15 @@ def img_array_factory():
return _create_img_array return _create_img_array
@pytest.fixture(scope="session")
def img_factory(img_array_factory):
    """Session-scoped fixture returning a builder for PIL images.

    The returned callable accepts ``width`` and ``height`` (both default to
    100), asks ``img_array_factory`` for an array of that size, and wraps it
    in a ``PIL.Image.Image``.

    NOTE(review): the array's exact shape/dtype is decided by
    ``img_array_factory``, whose body is not fully visible here; it must be a
    layout that ``PIL.Image.fromarray`` accepts.
    """

    def _create_img(width=100, height=100) -> PIL.Image.Image:
        img_array = img_array_factory(width=width, height=height)
        # ``fromarray`` is a module-level function of ``PIL.Image``; it is not
        # an attribute of the ``PIL.Image.Image`` class, so calling it through
        # the class raises AttributeError.
        return PIL.Image.fromarray(img_array)

    return _create_img
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def info_factory(): def info_factory():
def _create_info( def _create_info(

View File

@ -15,10 +15,8 @@
# limitations under the License. # limitations under the License.
from pathlib import Path from pathlib import Path
import numpy as np
import pytest import pytest
import torch import torch
from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from torchvision.transforms import v2 from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812 from torchvision.transforms.v2 import functional as F # noqa: N812
@ -32,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp" DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
@pytest.fixture
def img_tensor() -> torch.Tensor:
return torch.rand((3, 480, 640), dtype=torch.float32)
@pytest.fixture
def img() -> Image:
img_array = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
return Image.fromarray(img_array)
@pytest.fixture @pytest.fixture
def color_jitters(): def color_jitters():
return [ return [
@ -66,47 +49,54 @@ def default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors") return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
def test_get_image_transforms_no_transform(img_tensor): def test_get_image_transforms_no_transform(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0) tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
torch.testing.assert_close(tf_actual(img_tensor), img_tensor) torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img_tensor, min_max): def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max) tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
tf_expected = v2.ColorJitter(brightness=min_max) tf_expected = v2.ColorJitter(brightness=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img_tensor, min_max): def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max) tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
tf_expected = v2.ColorJitter(contrast=min_max) tf_expected = v2.ColorJitter(contrast=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img_tensor, min_max): def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max) tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
tf_expected = v2.ColorJitter(saturation=min_max) tf_expected = v2.ColorJitter(saturation=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)]) @pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img_tensor, min_max): def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max) tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
tf_expected = v2.ColorJitter(hue=min_max) tf_expected = v2.ColorJitter(hue=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) @pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img_tensor, min_max): def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max) tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
tf_expected = SharpnessJitter(sharpness=min_max) tf_expected = SharpnessJitter(sharpness=min_max)
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
def test_get_image_transforms_max_num_transforms(img_tensor): def test_get_image_transforms_max_num_transforms(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms( tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5), brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5), contrast_min_max=(0.5, 0.5),
@ -128,8 +118,9 @@ def test_get_image_transforms_max_num_transforms(img_tensor):
@require_x86_64_kernel @require_x86_64_kernel
def test_get_image_transforms_random_order(img_tensor): def test_get_image_transforms_random_order(img_tensor_factory):
out_imgs = [] out_imgs = []
img_tensor = img_tensor_factory()
tf = get_image_transforms( tf = get_image_transforms(
brightness_min_max=(0.5, 0.5), brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5), contrast_min_max=(0.5, 0.5),
@ -147,6 +138,7 @@ def test_get_image_transforms_random_order(img_tensor):
torch.testing.assert_close(out_imgs[0], out_imgs[i]) torch.testing.assert_close(out_imgs[0], out_imgs[i])
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"transform, min_max_values", "transform, min_max_values",
[ [
@ -157,7 +149,8 @@ def test_get_image_transforms_random_order(img_tensor):
("sharpness", [(0.5, 0.5), (2.0, 2.0)]), ("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
], ],
) )
def test_backward_compatibility_torchvision(transform, min_max_values, img_tensor, single_transforms): def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
img_tensor = img_tensor_factory()
for min_max in min_max_values: for min_max in min_max_values:
kwargs = { kwargs = {
f"{transform}_weight": 1.0, f"{transform}_weight": 1.0,
@ -170,8 +163,10 @@ def test_backward_compatibility_torchvision(transform, min_max_values, img_tenso
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@require_x86_64_kernel @require_x86_64_kernel
def test_backward_compatibility_default_config(img_tensor, default_transforms): def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
img_tensor = img_tensor_factory()
cfg = init_hydra_config(DEFAULT_CONFIG_PATH) cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms( default_tf = get_image_transforms(
@ -198,7 +193,8 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
@pytest.mark.parametrize("p", [[0, 1], [1, 0]]) @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img_tensor): def test_random_subset_apply_single_choice(img_tensor_factory, p):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False) random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
actual = random_choice(img_tensor) actual = random_choice(img_tensor)
@ -210,7 +206,8 @@ def test_random_subset_apply_single_choice(p, img_tensor):
torch.testing.assert_close(actual, F.vertical_flip(img_tensor)) torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
def test_random_subset_apply_random_order(img_tensor): def test_random_subset_apply_random_order(img_tensor_factory):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True) random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# We can't really check whether the transforms are actually applied in random order. However, # We can't really check whether the transforms are actually applied in random order. However,
@ -221,7 +218,8 @@ def test_random_subset_apply_random_order(img_tensor):
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
def test_random_subset_apply_valid_transforms(color_jitters, img_tensor): def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
img_tensor = img_tensor_factory()
transform = RandomSubsetApply(color_jitters) transform = RandomSubsetApply(color_jitters)
output = transform(img_tensor) output = transform(img_tensor)
assert output.shape == img_tensor.shape assert output.shape == img_tensor.shape
@ -238,13 +236,15 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
RandomSubsetApply(color_jitters, n_subset=n_subset) RandomSubsetApply(color_jitters, n_subset=n_subset)
def test_sharpness_jitter_valid_range_tuple(img_tensor): def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter((0.1, 2.0)) tf = SharpnessJitter((0.1, 2.0))
output = tf(img_tensor) output = tf(img_tensor)
assert output.shape == img_tensor.shape assert output.shape == img_tensor.shape
def test_sharpness_jitter_valid_range_float(img_tensor): def test_sharpness_jitter_valid_range_float(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter(0.5) tf = SharpnessJitter(0.5)
output = tf(img_tensor) output = tf(img_tensor)
assert output.shape == img_tensor.shape assert output.shape == img_tensor.shape
@ -260,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1)) SharpnessJitter((2.0, 0.1))
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"repo_id, n_examples", "repo_id, n_examples",
[ [