Add img and img_tensor factories
This commit is contained in:
parent
293bdc7f67
commit
375abd3020
|
@ -4,7 +4,9 @@ from unittest.mock import patch
|
|||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
|
@ -37,6 +39,14 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
|
|||
return task_to_task_index[task]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
def _create_img_tensor(width=100, height=100) -> torch.Tensor:
|
||||
return torch.rand((3, height, width), dtype=torch.float32)
|
||||
|
||||
return _create_img_tensor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(width=100, height=100) -> np.ndarray:
|
||||
|
@ -45,6 +55,15 @@ def img_array_factory():
|
|||
return _create_img_array
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_factory(img_array_factory):
|
||||
def _create_img(width=100, height=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(width=width, height=height)
|
||||
return PIL.Image.Image.fromarray(img_array)
|
||||
|
||||
return _create_img
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory():
|
||||
def _create_info(
|
||||
|
|
|
@ -15,10 +15,8 @@
|
|||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
@ -32,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
|||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
def load_png_to_tensor(path: Path):
|
||||
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img_tensor() -> torch.Tensor:
|
||||
return torch.rand((3, 480, 640), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img() -> Image:
|
||||
img_array = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
|
||||
return Image.fromarray(img_array)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def color_jitters():
|
||||
return [
|
||||
|
@ -66,47 +49,54 @@ def default_transforms():
|
|||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform(img_tensor):
|
||||
def test_get_image_transforms_no_transform(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_brightness(img_tensor, min_max):
|
||||
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_contrast(img_tensor, min_max):
|
||||
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_saturation(img_tensor, min_max):
|
||||
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
||||
def test_get_image_transforms_hue(img_tensor, min_max):
|
||||
def test_get_image_transforms_hue(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(hue=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_sharpness(img_tensor, min_max):
|
||||
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
|
||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
def test_get_image_transforms_max_num_transforms(img_tensor):
|
||||
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
|
@ -128,8 +118,9 @@ def test_get_image_transforms_max_num_transforms(img_tensor):
|
|||
|
||||
|
||||
@require_x86_64_kernel
|
||||
def test_get_image_transforms_random_order(img_tensor):
|
||||
def test_get_image_transforms_random_order(img_tensor_factory):
|
||||
out_imgs = []
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
|
@ -147,6 +138,7 @@ def test_get_image_transforms_random_order(img_tensor):
|
|||
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"transform, min_max_values",
|
||||
[
|
||||
|
@ -157,7 +149,8 @@ def test_get_image_transforms_random_order(img_tensor):
|
|||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_torchvision(transform, min_max_values, img_tensor, single_transforms):
|
||||
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
for min_max in min_max_values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
|
@ -170,8 +163,10 @@ def test_backward_compatibility_torchvision(transform, min_max_values, img_tenso
|
|||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
|
@ -198,7 +193,8 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
||||
def test_random_subset_apply_single_choice(p, img_tensor):
|
||||
def test_random_subset_apply_single_choice(img_tensor_factory, p):
|
||||
img_tensor = img_tensor_factory()
|
||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
|
||||
actual = random_choice(img_tensor)
|
||||
|
@ -210,7 +206,8 @@ def test_random_subset_apply_single_choice(p, img_tensor):
|
|||
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
|
||||
|
||||
|
||||
def test_random_subset_apply_random_order(img_tensor):
|
||||
def test_random_subset_apply_random_order(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
||||
# We can't really check whether the transforms are actually applied in random order. However,
|
||||
|
@ -221,7 +218,8 @@ def test_random_subset_apply_random_order(img_tensor):
|
|||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_random_subset_apply_valid_transforms(color_jitters, img_tensor):
|
||||
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
|
||||
img_tensor = img_tensor_factory()
|
||||
transform = RandomSubsetApply(color_jitters)
|
||||
output = transform(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
@ -238,13 +236,15 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
|
|||
RandomSubsetApply(color_jitters, n_subset=n_subset)
|
||||
|
||||
|
||||
def test_sharpness_jitter_valid_range_tuple(img_tensor):
|
||||
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = SharpnessJitter((0.1, 2.0))
|
||||
output = tf(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
||||
|
||||
def test_sharpness_jitter_valid_range_float(img_tensor):
|
||||
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = SharpnessJitter(0.5)
|
||||
output = tf(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
@ -260,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
|||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, n_examples",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue