Add data augmentation in LeRobotDataset (#234)

Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
This commit is contained in:
Marina Barannikov 2024-06-11 19:20:55 +02:00 committed by GitHub
parent 1cf050d412
commit ff8f6aa6cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 811 additions and 12 deletions

View File

@ -0,0 +1,52 @@
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
"""
from pathlib import Path
from torchvision.transforms import ToPILImage, v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
dataset_repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(dataset_repo_id)
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
# Define the transformations
transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)
# Save the original frame
to_pil = ToPILImage()
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
# Save the transformed frame
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")

View File

@ -19,6 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
def resolve_delta_timestamps(cfg):
@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg)
# TODO(rcadene): add data augmentations
image_transforms = None
if cfg.training.image_transforms.enable:
image_transforms = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
contrast_weight=cfg.contrast.weight,
contrast_min_max=cfg.contrast.min_max,
saturation_weight=cfg.saturation.weight,
saturation_min_max=cfg.saturation.min_max,
hue_weight=cfg.hue.weight,
hue_min_max=cfg.hue.min_max,
sharpness_weight=cfg.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms,
random_order=cfg.random_order,
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
if cfg.get("override_dataset_stats"):

View File

@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s,
)
if self.transform is not None:
item = self.transform(item)
if self.image_transforms is not None:
for cam in self.camera_keys:
item[cam] = self.image_transforms(item[cam])
return item
@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.version = version
obj.root = root
obj.split = split
obj.transform = transform
obj.image_transforms = transform
obj.delta_timestamps = delta_timestamps
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
image_transforms=image_transforms,
)
for repo_id in repo_ids
]
@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)

View File

@ -0,0 +1,197 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from typing import Any, Callable, Dict, Sequence
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812
class RandomSubsetApply(Transform):
"""Apply a random subset of N transformations from a list of transformations.
Args:
transforms: list of transformations.
p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
have the same probability.
n_subset: number of transformations to apply. If ``None``, all transforms are applied.
Must be in [1, len(transforms)].
random_order: apply transformations in a random order.
"""
def __init__(
self,
transforms: Sequence[Callable],
p: list[float] | None = None,
n_subset: int | None = None,
random_order: bool = False,
) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
)
if n_subset is None:
n_subset = len(transforms)
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
self.n_subset = n_subset
self.random_order = random_order
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
if not self.random_order:
selected_indices = selected_indices.sort().values
selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
return (
f"transforms={self.transforms}, "
f"p={self.p}, "
f"n_subset={self.n_subset}, "
f"random_order={self.random_order}"
)
class SharpnessJitter(Transform):
"""Randomly change the sharpness of an image or video.
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
While v2.RandomAdjustSharpness applies with a given probability a fixed sharpness_factor to an image,
SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
augmentations as a result.
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
by a factor of 2.
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
[max(0, 1 - sharpness), 1 + sharpness] or the given
[min, max]. Should be non negative numbers.
"""
def __init__(self, sharpness: float | Sequence[float]) -> None:
super().__init__()
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
return float(sharpness[0]), float(sharpness[1])
def _generate_value(self, left: float, right: float) -> float:
return torch.empty(1).uniform_(left, right).item()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
contrast_weight: float = 1.0,
contrast_min_max: tuple[float, float] | None = None,
saturation_weight: float = 1.0,
saturation_min_max: tuple[float, float] | None = None,
hue_weight: float = 1.0,
hue_min_max: tuple[float, float] | None = None,
sharpness_weight: float = 1.0,
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
)
if weight < 0.0:
raise ValueError(
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
)
check_value("brightness", brightness_weight, brightness_min_max)
check_value("contrast", contrast_weight, contrast_min_max)
check_value("saturation", saturation_weight, saturation_min_max)
check_value("hue", hue_weight, hue_min_max)
check_value("sharpness", sharpness_weight, sharpness_min_max)
weights = []
transforms = []
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None and contrast_weight > 0.0:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None and saturation_weight > 0.0:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None and hue_weight > 0.0:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
n_subset = len(transforms)
if max_num_transforms is not None:
n_subset = min(n_subset, max_num_transforms)
if n_subset == 0:
return v2.Identity()
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)

View File

@ -43,6 +43,40 @@ training:
save_checkpoint: true
num_workers: 4
batch_size: ???
image_transforms:
# These transforms are all using standard torchvision.transforms.v2
# You can find out how these transformations affect images here:
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
# We use a custom RandomSubsetApply container to sample them.
# For each transform, the following parameters are available:
# weight: This represents the multinomial probability (with no replacement)
# used for sampling the transform. If the sum of the weights is not 1,
# they will be normalized.
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
# (following uniform distribution) when it's applied.
# Set this flag to `true` to enable transforms during training
enable: false
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number of available transforms].
max_num_transforms: 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: false
brightness:
weight: 1
min_max: [0.8, 1.2]
contrast:
weight: 1
min_max: [0.8, 1.2]
saturation:
weight: 1
min_max: [0.5, 1.5]
hue:
weight: 1
min_max: [-0.05, 0.05]
sharpness:
weight: 1
min_max: [0.8, 1.2]
eval:
n_episodes: 1

View File

@ -0,0 +1,142 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize effects of image transforms for a given configuration.
This script will generate examples of transformed images as they are output by LeRobot dataset.
Additionally, each individual transform can be visualized separately as well as examples of combined transforms
--- Usage Examples ---
Increase hue jitter
```
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.hue.min_max=[-0.25,0.25]
```
Increase brightness & brightness weight
```
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.brightness.weight=10.0 \
training.image_transforms.brightness.min_max=[1.0,2.0]
```
Blur images and disable saturation & hue
```
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.sharpness.weight=10.0 \
training.image_transforms.sharpness.min_max=[0.0,1.0] \
training.image_transforms.saturation.weight=0.0 \
training.image_transforms.hue.weight=0.0
```
Use all transforms with random order
```
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.max_num_transforms=5 \
training.image_transforms.random_order=true
```
"""
from pathlib import Path
import hydra
from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
OUTPUT_DIR = Path("outputs/image_transforms")
N_EXAMPLES = 5
to_pil = ToPILImage()
def save_config_all_transforms(cfg, original_frame, output_dir):
tf = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
contrast_weight=cfg.contrast.weight,
contrast_min_max=cfg.contrast.min_max,
saturation_weight=cfg.saturation.weight,
saturation_min_max=cfg.saturation.min_max,
hue_weight=cfg.hue.weight,
hue_min_max=cfg.hue.min_max,
sharpness_weight=cfg.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms,
random_order=cfg.random_order,
)
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)
for i in range(1, N_EXAMPLES + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
print("Combined transforms examples saved to:")
print(f" {output_dir_all}")
def save_config_single_transforms(cfg, original_frame, output_dir):
transforms = [
"brightness",
"contrast",
"saturation",
"hue",
"sharpness",
]
print("Individual transforms examples saved to:")
for transform in transforms:
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": cfg[f"{transform}"].min_max,
}
tf = get_image_transforms(**kwargs)
output_dir_single = output_dir / f"{transform}"
output_dir_single.mkdir(parents=True, exist_ok=True)
for i in range(1, N_EXAMPLES + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
print(f" {output_dir_single}")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms(cfg):
dataset = LeRobotDataset(cfg.dataset_repo_id)
output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
output_dir.mkdir(parents=True, exist_ok=True)
# Get 1st frame from 1st camera of 1st episode
original_frame = dataset[0][dataset.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
if __name__ == "__main__":
visualize_transforms()

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a
size 3686488

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f
size 40551392

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import torch
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
with seeded_context(1337):
img_tf = default_tf(original_frame)
save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
transforms = {
"brightness": [(0.5, 0.5), (2.0, 2.0)],
"contrast": [(0.5, 0.5), (2.0, 2.0)],
"saturation": [(0.5, 0.5), (2.0, 2.0)],
"hue": [(-0.25, -0.25), (0.25, 0.25)],
"sharpness": [(0.5, 0.5), (2.0, 2.0)],
}
frames = {"original_frame": original_frame}
for transform, values in transforms.items():
for min_max in values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
frames[key] = tf(original_frame)
save_file(frames, output_dir / "single_transforms.safetensors")
def main():
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.camera_keys[0]]
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,260 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
@pytest.fixture
def img():
dataset = LeRobotDataset(DATASET_REPO_ID)
return dataset[0][dataset.camera_keys[0]]
@pytest.fixture
def img_random():
return torch.rand(3, 480, 640)
@pytest.fixture
def color_jitters():
return [
v2.ColorJitter(brightness=0.5),
v2.ColorJitter(contrast=0.5),
v2.ColorJitter(saturation=0.5),
]
@pytest.fixture
def single_transforms():
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
@pytest.fixture
def default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
def test_get_image_transforms_no_transform(img):
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
torch.testing.assert_close(tf_actual(img), img)
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
tf_expected = v2.ColorJitter(brightness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
tf_expected = v2.ColorJitter(contrast=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
tf_expected = v2.ColorJitter(saturation=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
tf_expected = v2.ColorJitter(hue=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
tf_expected = SharpnessJitter(sharpness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_max_num_transforms(img):
tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=False,
)
tf_expected = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 0.5)),
v2.ColorJitter(contrast=(0.5, 0.5)),
v2.ColorJitter(saturation=(0.5, 0.5)),
v2.ColorJitter(hue=(0.5, 0.5)),
SharpnessJitter(sharpness=(0.5, 0.5)),
]
)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
out_imgs = []
tf = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=True,
)
with seeded_context(1337):
for _ in range(10):
out_imgs.append(tf(img))
for i in range(1, len(out_imgs)):
with pytest.raises(AssertionError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
@pytest.mark.parametrize(
"transform, min_max_values",
[
("brightness", [(0.5, 0.5), (2.0, 2.0)]),
("contrast", [(0.5, 0.5), (2.0, 2.0)]),
("saturation", [(0.5, 0.5), (2.0, 2.0)]),
("hue", [(-0.25, -0.25), (0.25, 0.25)]),
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
for min_max in min_max_values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
actual = tf(img)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
expected = single_transforms[key]
torch.testing.assert_close(actual, expected)
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
with seeded_context(1337):
actual = default_tf(img)
expected = default_transforms["default"]
torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
actual = random_choice(img)
p_horz, _ = p
if p_horz:
torch.testing.assert_close(actual, F.horizontal_flip(img))
else:
torch.testing.assert_close(actual, F.vertical_flip(img))
def test_random_subset_apply_random_order(img):
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# We can't really check whether the transforms are actually applied in random order. However,
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# applies them in random order, we can use a fixed order to compute the expected value.
actual = random_order(img)
expected = v2.Compose(flips)(img)
torch.testing.assert_close(actual, expected)
def test_random_subset_apply_valid_transforms(color_jitters, img):
transform = RandomSubsetApply(color_jitters)
output = transform(img)
assert output.shape == img.shape
def test_random_subset_apply_probability_length_mismatch(color_jitters):
with pytest.raises(ValueError):
RandomSubsetApply(color_jitters, p=[0.5, 0.5])
@pytest.mark.parametrize("n_subset", [0, 5])
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
with pytest.raises(ValueError):
RandomSubsetApply(color_jitters, n_subset=n_subset)
def test_sharpness_jitter_valid_range_tuple(img):
tf = SharpnessJitter((0.1, 2.0))
output = tf(img)
assert output.shape == img.shape
def test_sharpness_jitter_valid_range_float(img):
tf = SharpnessJitter(0.5)
output = tf(img)
assert output.shape == img.shape
def test_sharpness_jitter_invalid_range_min_negative():
with pytest.raises(ValueError):
SharpnessJitter((-0.1, 2.0))
def test_sharpness_jitter_invalid_range_max_smaller():
with pytest.raises(ValueError):
SharpnessJitter((2.0, 0.1))