Handle n_subset == 0

This commit is contained in:
Simon Alibert 2024-06-10 09:56:58 +02:00
parent 9dad7fb0a9
commit fb0f69ee65
2 changed files with 41 additions and 23 deletions

View File

@ -42,8 +42,8 @@ class RandomSubsetApply(Transform):
n_subset = len(transforms)
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (0 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
@ -130,7 +130,7 @@ def get_image_transforms(
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value_error(name, weight, min_max):
def check_value(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(
@ -141,27 +141,27 @@ def get_image_transforms(
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
)
check_value_error("brightness", brightness_weight, brightness_min_max)
check_value_error("contrast", contrast_weight, contrast_min_max)
check_value_error("saturation", saturation_weight, saturation_min_max)
check_value_error("hue", hue_weight, hue_min_max)
check_value_error("sharpness", sharpness_weight, sharpness_min_max)
check_value("brightness", brightness_weight, brightness_min_max)
check_value("contrast", contrast_weight, contrast_min_max)
check_value("saturation", saturation_weight, saturation_min_max)
check_value("hue", hue_weight, hue_min_max)
check_value("sharpness", sharpness_weight, sharpness_min_max)
weights = []
transforms = []
if brightness_min_max is not None:
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None:
if contrast_min_max is not None and contrast_weight > 0.0:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None:
if saturation_min_max is not None and saturation_weight > 0.0:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None:
if hue_min_max is not None and hue_weight > 0.0:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None:
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
@ -169,7 +169,8 @@ def get_image_transforms(
if max_num_transforms is not None:
n_subset = min(n_subset, max_num_transforms)
final_transforms = RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return final_transforms
if n_subset == 0:
return v2.Identity()
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)

View File

@ -17,12 +17,6 @@ from lerobot.common.utils.utils import seeded_context
# - save artifacts
def test_get_image_transforms_no_transform():
get_image_transforms()
get_image_transforms(sharpness_weight=0.0)
get_image_transforms(max_num_transforms=0)
@pytest.fixture
def img():
# dataset = LeRobotDataset("lerobot/pusht")
@ -33,6 +27,29 @@ def img():
return img_chw
@pytest.fixture
def kwargs():
return {
"brightness_weight": 0.0,
"brightness_min_max": (0.5, 0.5),
"contrast_weight": 0.0,
"contrast_min_max": (0.5, 0.5),
"saturation_weight": 0.0,
"saturation_min_max": (0.5, 0.5),
"hue_weight": 0.0,
"hue_min_max": (0.5, 0.5),
"sharpness_weight": 0.0,
"sharpness_min_max": (0.5, 0.5),
"max_num_transforms": 3,
"random_order": False,
}
def test_get_image_transforms_no_transform(img):
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
torch.testing.assert_close(tf_actual(img), img)
def test_get_image_transforms_brightness(img):
brightness_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=brightness_min_max)