Handle n_subset == 0
This commit is contained in:
parent
9dad7fb0a9
commit
fb0f69ee65
|
@ -42,8 +42,8 @@ class RandomSubsetApply(Transform):
|
|||
n_subset = len(transforms)
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (0 <= n_subset <= len(transforms)):
|
||||
raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
|
@ -130,7 +130,7 @@ def get_image_transforms(
|
|||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
def check_value_error(name, weight, min_max):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
|
@ -141,27 +141,27 @@ def get_image_transforms(
|
|||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
|
||||
check_value_error("brightness", brightness_weight, brightness_min_max)
|
||||
check_value_error("contrast", contrast_weight, contrast_min_max)
|
||||
check_value_error("saturation", saturation_weight, saturation_min_max)
|
||||
check_value_error("hue", hue_weight, hue_min_max)
|
||||
check_value_error("sharpness", sharpness_weight, sharpness_min_max)
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None:
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None:
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None:
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None:
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None:
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
|
||||
|
@ -169,7 +169,8 @@ def get_image_transforms(
|
|||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
|
||||
final_transforms = RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return final_transforms
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
|
|
|
@ -17,12 +17,6 @@ from lerobot.common.utils.utils import seeded_context
|
|||
# - save artifacts
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform():
|
||||
get_image_transforms()
|
||||
get_image_transforms(sharpness_weight=0.0)
|
||||
get_image_transforms(max_num_transforms=0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img():
|
||||
# dataset = LeRobotDataset("lerobot/pusht")
|
||||
|
@ -33,6 +27,29 @@ def img():
|
|||
return img_chw
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kwargs():
|
||||
return {
|
||||
"brightness_weight": 0.0,
|
||||
"brightness_min_max": (0.5, 0.5),
|
||||
"contrast_weight": 0.0,
|
||||
"contrast_min_max": (0.5, 0.5),
|
||||
"saturation_weight": 0.0,
|
||||
"saturation_min_max": (0.5, 0.5),
|
||||
"hue_weight": 0.0,
|
||||
"hue_min_max": (0.5, 0.5),
|
||||
"sharpness_weight": 0.0,
|
||||
"sharpness_min_max": (0.5, 0.5),
|
||||
"max_num_transforms": 3,
|
||||
"random_order": False,
|
||||
}
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform(img):
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
torch.testing.assert_close(tf_actual(img), img)
|
||||
|
||||
|
||||
def test_get_image_transforms_brightness(img):
|
||||
brightness_min_max = (0.5, 0.5)
|
||||
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=brightness_min_max)
|
||||
|
|
Loading…
Reference in New Issue