diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py
new file mode 100644
index 00000000..bdcc6d7b
--- /dev/null
+++ b/examples/6_add_image_transforms.py
@@ -0,0 +1,52 @@
+"""
+This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
+augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
+transforms are applied to the observation images before they are returned in the dataset's __get_item__.
+"""
+
+from pathlib import Path
+
+from torchvision.transforms import ToPILImage, v2
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+dataset_repo_id = "lerobot/aloha_static_tape"
+
+# Create a LeRobotDataset with no transformations
+dataset = LeRobotDataset(dataset_repo_id)
+# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
+
+# Get the index of the first observation in the first episode
+first_idx = dataset.episode_data_index["from"][0].item()
+
+# Get the frame corresponding to the first camera
+frame = dataset[first_idx][dataset.camera_keys[0]]
+
+
+# Define the transformations
+transforms = v2.Compose(
+    [
+        v2.ColorJitter(brightness=(0.5, 1.5)),
+        v2.ColorJitter(contrast=(0.5, 1.5)),
+        v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
+    ]
+)
+
+# Create another LeRobotDataset with the defined transformations
+transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
+
+# Get a frame from the transformed dataset
+transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
+
+# Create a directory to store output images
+output_dir = Path("outputs/image_transforms")
+output_dir.mkdir(parents=True, exist_ok=True)
+
+# Save the original frame
+to_pil = ToPILImage()
+to_pil(frame).save(output_dir / "original_frame.png", quality=100)
+print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
+
+# Save the transformed frame
+to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
+print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 4732f577..fab8ca57 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -19,6 +19,7 @@ import torch
 from omegaconf import ListConfig, OmegaConf
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
+from lerobot.common.datasets.transforms import get_image_transforms
 
 
 def resolve_delta_timestamps(cfg):
@@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
 
     resolve_delta_timestamps(cfg)
 
-    # TODO(rcadene): add data augmentations
+    image_transforms = None
+    if cfg.training.image_transforms.enable:
+        image_transforms = get_image_transforms(
+            brightness_weight=cfg.brightness.weight,
+            brightness_min_max=cfg.brightness.min_max,
+            contrast_weight=cfg.contrast.weight,
+            contrast_min_max=cfg.contrast.min_max,
+            saturation_weight=cfg.saturation.weight,
+            saturation_min_max=cfg.saturation.min_max,
+            hue_weight=cfg.hue.weight,
+            hue_min_max=cfg.hue.min_max,
+            sharpness_weight=cfg.sharpness.weight,
+            sharpness_min_max=cfg.sharpness.min_max,
+            max_num_transforms=cfg.max_num_transforms,
+            random_order=cfg.random_order,
+        )
 
     if isinstance(cfg.dataset_repo_id, str):
         dataset = LeRobotDataset(
             cfg.dataset_repo_id,
             split=split,
             delta_timestamps=cfg.training.get("delta_timestamps"),
+            image_transforms=image_transforms,
         )
     else:
         dataset = MultiLeRobotDataset(
-            cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
+            cfg.dataset_repo_id,
+            split=split,
+            delta_timestamps=cfg.training.get("delta_timestamps"),
+            image_transforms=image_transforms,
         )
 
     if cfg.get("override_dataset_stats"):
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 58ae51b1..d680b987 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         # load data from hub or locally when root is provided
         # TODO(rcadene, aliberts): implement faster transfer
@@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self.tolerance_s,
             )
 
-        if self.transform is not None:
-            item = self.transform(item)
+        if self.image_transforms is not None:
+            for cam in self.camera_keys:
+                item[cam] = self.image_transforms(item[cam])
 
         return item
 
@@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             f"  Recorded Frames per Second: {self.fps},\n"
             f"  Camera Keys: {self.camera_keys},\n"
             f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
-            f"  Transformations: {self.transform},\n"
+            f"  Transformations: {self.image_transforms},\n"
             f")"
         )
 
@@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.version = version
         obj.root = root
         obj.split = split
-        obj.transform = transform
+        obj.image_transforms = transform
         obj.delta_timestamps = delta_timestamps
         obj.hf_dataset = hf_dataset
         obj.episode_data_index = episode_data_index
@@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 root=root,
                 split=split,
                 delta_timestamps=delta_timestamps,
-                transform=transform,
+                image_transforms=image_transforms,
             )
             for repo_id in repo_ids
         ]
@@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.stats = aggregate_stats(self._datasets)
 
@@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         for data_key in self.disabled_data_keys:
             if data_key in item:
                 del item[data_key]
+
         return item
 
     def __repr__(self):
@@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
             f"  Recorded Frames per Second: {self.fps},\n"
             f"  Camera Keys: {self.camera_keys},\n"
             f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
-            f"  Transformations: {self.transform},\n"
+            f"  Transformations: {self.image_transforms},\n"
             f")"
         )
diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
new file mode 100644
index 00000000..899f0d66
--- /dev/null
+++ b/lerobot/common/datasets/transforms.py
@@ -0,0 +1,197 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+from typing import Any, Callable, Dict, Sequence
+
+import torch
+from torchvision.transforms import v2
+from torchvision.transforms.v2 import Transform
+from torchvision.transforms.v2 import functional as F  # noqa: N812
+
+
+class RandomSubsetApply(Transform):
+    """Apply a random subset of N transformations from a list of transformations.
+
+    Args:
+        transforms: list of transformations.
+        p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
+            If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
+            have the same probability.
+        n_subset: number of transformations to apply. If ``None``, all transforms are applied.
+            Must be in [1, len(transforms)].
+        random_order: apply transformations in a random order.
+    """
+
+    def __init__(
+        self,
+        transforms: Sequence[Callable],
+        p: list[float] | None = None,
+        n_subset: int | None = None,
+        random_order: bool = False,
+    ) -> None:
+        super().__init__()
+        if not isinstance(transforms, Sequence):
+            raise TypeError("Argument transforms should be a sequence of callables")
+        if p is None:
+            p = [1] * len(transforms)
+        elif len(p) != len(transforms):
+            raise ValueError(
+                f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
+            )
+
+        if n_subset is None:
+            n_subset = len(transforms)
+        elif not isinstance(n_subset, int):
+            raise TypeError("n_subset should be an int or None")
+        elif not (1 <= n_subset <= len(transforms)):
+            raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
+
+        self.transforms = transforms
+        total = sum(p)
+        self.p = [prob / total for prob in p]
+        self.n_subset = n_subset
+        self.random_order = random_order
+
+    def forward(self, *inputs: Any) -> Any:
+        needs_unpacking = len(inputs) > 1
+
+        selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
+        if not self.random_order:
+            selected_indices = selected_indices.sort().values
+
+        selected_transforms = [self.transforms[i] for i in selected_indices]
+
+        for transform in selected_transforms:
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+
+        return outputs
+
+    def extra_repr(self) -> str:
+        return (
+            f"transforms={self.transforms}, "
+            f"p={self.p}, "
+            f"n_subset={self.n_subset}, "
+            f"random_order={self.random_order}"
+        )
+
+
+class SharpnessJitter(Transform):
+    """Randomly change the sharpness of an image or video.
+
+    Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
+    While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
+    SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
+    augmentations as a result.
+
+    A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
+    by a factor of 2.
+
+    If the input is a :class:`torch.Tensor`,
+    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+    Args:
+        sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
+            [max(0, 1 - sharpness), 1 + sharpness] or the given
+            [min, max]. Should be non negative numbers.
+    """
+
+    def __init__(self, sharpness: float | Sequence[float]) -> None:
+        super().__init__()
+        self.sharpness = self._check_input(sharpness)
+
+    def _check_input(self, sharpness):
+        if isinstance(sharpness, (int, float)):
+            if sharpness < 0:
+                raise ValueError("If sharpness is a single number, it must be non negative.")
+            sharpness = [1.0 - sharpness, 1.0 + sharpness]
+            sharpness[0] = max(sharpness[0], 0.0)
+        elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
+            sharpness = [float(v) for v in sharpness]
+        else:
+            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
+
+        if not 0.0 <= sharpness[0] <= sharpness[1]:
+            raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
+
+        return float(sharpness[0]), float(sharpness[1])
+
+    def _generate_value(self, left: float, right: float) -> float:
+        return torch.empty(1).uniform_(left, right).item()
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
+
+
+def get_image_transforms(
+    brightness_weight: float = 1.0,
+    brightness_min_max: tuple[float, float] | None = None,
+    contrast_weight: float = 1.0,
+    contrast_min_max: tuple[float, float] | None = None,
+    saturation_weight: float = 1.0,
+    saturation_min_max: tuple[float, float] | None = None,
+    hue_weight: float = 1.0,
+    hue_min_max: tuple[float, float] | None = None,
+    sharpness_weight: float = 1.0,
+    sharpness_min_max: tuple[float, float] | None = None,
+    max_num_transforms: int | None = None,
+    random_order: bool = False,
+):
+    def check_value(name, weight, min_max):
+        if min_max is not None:
+            if len(min_max) != 2:
+                raise ValueError(
+                    f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
+                )
+            if weight < 0.0:
+                raise ValueError(
+                    f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
+                )
+
+    check_value("brightness", brightness_weight, brightness_min_max)
+    check_value("contrast", contrast_weight, contrast_min_max)
+    check_value("saturation", saturation_weight, saturation_min_max)
+    check_value("hue", hue_weight, hue_min_max)
+    check_value("sharpness", sharpness_weight, sharpness_min_max)
+
+    weights = []
+    transforms = []
+    if brightness_min_max is not None and brightness_weight > 0.0:
+        weights.append(brightness_weight)
+        transforms.append(v2.ColorJitter(brightness=brightness_min_max))
+    if contrast_min_max is not None and contrast_weight > 0.0:
+        weights.append(contrast_weight)
+        transforms.append(v2.ColorJitter(contrast=contrast_min_max))
+    if saturation_min_max is not None and saturation_weight > 0.0:
+        weights.append(saturation_weight)
+        transforms.append(v2.ColorJitter(saturation=saturation_min_max))
+    if hue_min_max is not None and hue_weight > 0.0:
+        weights.append(hue_weight)
+        transforms.append(v2.ColorJitter(hue=hue_min_max))
+    if sharpness_min_max is not None and sharpness_weight > 0.0:
+        weights.append(sharpness_weight)
+        transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
+
+    n_subset = len(transforms)
+    if max_num_transforms is not None:
+        n_subset = min(n_subset, max_num_transforms)
+
+    if n_subset == 0:
+        return v2.Identity()
+    else:
+        # TODO(rcadene, aliberts): add v2.ToDtype float16?
+        return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 85b9ceea..6101df89 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -43,6 +43,40 @@ training:
   save_checkpoint: true
   num_workers: 4
   batch_size: ???
+  image_transforms:
+  # These transforms are all using standard torchvision.transforms.v2
+  # You can find out how these transformations affect images here:
+  # https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
+  # We use a custom RandomSubsetApply container to sample them.
+  # For each transform, the following parameters are available:
+  #   weight: This represents the multinomial probability (with no replacement)
+  #           used for sampling the transform. If the sum of the weights is not 1,
+  #           they will be normalized.
+  #   min_max: Lower & upper bound respectively used for sampling the transform's parameter
+  #           (following uniform distribution) when it's applied.
+    # Set this flag to `true` to enable transforms during training
+    enable: false
+    # This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
+    # It's an integer in the interval [1, number of available transforms].
+    max_num_transforms: 3
+    # By default, transforms are applied in Torchvision's suggested order (shown below).
+    # Set this to True to apply them in a random order.
+    random_order: false
+    brightness:
+      weight: 1
+      min_max: [0.8, 1.2]
+    contrast:
+      weight: 1
+      min_max: [0.8, 1.2]
+    saturation:
+      weight: 1
+      min_max: [0.5, 1.5]
+    hue:
+      weight: 1
+      min_max: [-0.05, 0.05]
+    sharpness:
+      weight: 1
+      min_max: [0.8, 1.2]
 
 eval:
   n_episodes: 1
diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py
new file mode 100644
index 00000000..fa3c0ab2
--- /dev/null
+++ b/lerobot/scripts/visualize_image_transforms.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Visualize effects of image transforms for a given configuration.
+
+This script will generate examples of transformed images as they are output by LeRobot dataset.
+Additionally, each individual transform can be visualized separately as well as examples of combined transforms
+
+
+--- Usage Examples ---
+
+Increase hue jitter
+```
+python lerobot/scripts/visualize_image_transforms.py \
+    dataset_repo_id=lerobot/aloha_mobile_shrimp \
+    training.image_transforms.hue.min_max=[-0.25,0.25]
+```
+
+Increase brightness & brightness weight
+```
+python lerobot/scripts/visualize_image_transforms.py \
+    dataset_repo_id=lerobot/aloha_mobile_shrimp \
+    training.image_transforms.brightness.weight=10.0 \
+    training.image_transforms.brightness.min_max=[1.0,2.0]
+```
+
+Blur images and disable saturation & hue
+```
+python lerobot/scripts/visualize_image_transforms.py \
+    dataset_repo_id=lerobot/aloha_mobile_shrimp \
+    training.image_transforms.sharpness.weight=10.0 \
+    training.image_transforms.sharpness.min_max=[0.0,1.0] \
+    training.image_transforms.saturation.weight=0.0 \
+    training.image_transforms.hue.weight=0.0
+```
+
+Use all transforms with random order
+```
+python lerobot/scripts/visualize_image_transforms.py \
+    dataset_repo_id=lerobot/aloha_mobile_shrimp \
+    training.image_transforms.max_num_transforms=5 \
+    training.image_transforms.random_order=true
+```
+
+"""
+
+from pathlib import Path
+
+import hydra
+from torchvision.transforms import ToPILImage
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.transforms import get_image_transforms
+
+OUTPUT_DIR = Path("outputs/image_transforms")
+N_EXAMPLES = 5
+to_pil = ToPILImage()
+
+
+def save_config_all_transforms(cfg, original_frame, output_dir):
+    tf = get_image_transforms(
+        brightness_weight=cfg.brightness.weight,
+        brightness_min_max=cfg.brightness.min_max,
+        contrast_weight=cfg.contrast.weight,
+        contrast_min_max=cfg.contrast.min_max,
+        saturation_weight=cfg.saturation.weight,
+        saturation_min_max=cfg.saturation.min_max,
+        hue_weight=cfg.hue.weight,
+        hue_min_max=cfg.hue.min_max,
+        sharpness_weight=cfg.sharpness.weight,
+        sharpness_min_max=cfg.sharpness.min_max,
+        max_num_transforms=cfg.max_num_transforms,
+        random_order=cfg.random_order,
+    )
+
+    output_dir_all = output_dir / "all"
+    output_dir_all.mkdir(parents=True, exist_ok=True)
+
+    for i in range(1, N_EXAMPLES + 1):
+        transformed_frame = tf(original_frame)
+        to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
+
+    print("Combined transforms examples saved to:")
+    print(f"    {output_dir_all}")
+
+
+def save_config_single_transforms(cfg, original_frame, output_dir):
+    transforms = [
+        "brightness",
+        "contrast",
+        "saturation",
+        "hue",
+        "sharpness",
+    ]
+    print("Individual transforms examples saved to:")
+    for transform in transforms:
+        kwargs = {
+            f"{transform}_weight": cfg[f"{transform}"].weight,
+            f"{transform}_min_max": cfg[f"{transform}"].min_max,
+        }
+        tf = get_image_transforms(**kwargs)
+        output_dir_single = output_dir / f"{transform}"
+        output_dir_single.mkdir(parents=True, exist_ok=True)
+
+        for i in range(1, N_EXAMPLES + 1):
+            transformed_frame = tf(original_frame)
+            to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
+
+        print(f"    {output_dir_single}")
+
+
+@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
+def visualize_transforms(cfg):
+    dataset = LeRobotDataset(cfg.dataset_repo_id)
+
+    output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Get 1st frame from 1st camera of 1st episode
+    original_frame = dataset[0][dataset.camera_keys[0]]
+    to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
+    print("\nOriginal frame saved to:")
+    print(f"    {output_dir / 'original_frame.png'}.")
+
+    save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
+    save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
+
+
+if __name__ == "__main__":
+    visualize_transforms()
diff --git a/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors
new file mode 100644
index 00000000..77699dab
--- /dev/null
+++ b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a
+size 3686488
diff --git a/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors
new file mode 100644
index 00000000..13f1033f
--- /dev/null
+++ b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f
+size 40551392
diff --git a/tests/scripts/save_image_transforms_to_safetensors.py b/tests/scripts/save_image_transforms_to_safetensors.py
new file mode 100644
index 00000000..9d024a01
--- /dev/null
+++ b/tests/scripts/save_image_transforms_to_safetensors.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+
+import torch
+from safetensors.torch import save_file
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.transforms import get_image_transforms
+from lerobot.common.utils.utils import init_hydra_config, seeded_context
+from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
+from tests.utils import DEFAULT_CONFIG_PATH
+
+
+def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
+    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
+    cfg_tf = cfg.training.image_transforms
+    default_tf = get_image_transforms(
+        brightness_weight=cfg_tf.brightness.weight,
+        brightness_min_max=cfg_tf.brightness.min_max,
+        contrast_weight=cfg_tf.contrast.weight,
+        contrast_min_max=cfg_tf.contrast.min_max,
+        saturation_weight=cfg_tf.saturation.weight,
+        saturation_min_max=cfg_tf.saturation.min_max,
+        hue_weight=cfg_tf.hue.weight,
+        hue_min_max=cfg_tf.hue.min_max,
+        sharpness_weight=cfg_tf.sharpness.weight,
+        sharpness_min_max=cfg_tf.sharpness.min_max,
+        max_num_transforms=cfg_tf.max_num_transforms,
+        random_order=cfg_tf.random_order,
+    )
+
+    with seeded_context(1337):
+        img_tf = default_tf(original_frame)
+
+    save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
+
+
+def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
+    transforms = {
+        "brightness": [(0.5, 0.5), (2.0, 2.0)],
+        "contrast": [(0.5, 0.5), (2.0, 2.0)],
+        "saturation": [(0.5, 0.5), (2.0, 2.0)],
+        "hue": [(-0.25, -0.25), (0.25, 0.25)],
+        "sharpness": [(0.5, 0.5), (2.0, 2.0)],
+    }
+
+    frames = {"original_frame": original_frame}
+    for transform, values in transforms.items():
+        for min_max in values:
+            kwargs = {
+                f"{transform}_weight": 1.0,
+                f"{transform}_min_max": min_max,
+            }
+            tf = get_image_transforms(**kwargs)
+            key = f"{transform}_{min_max[0]}_{min_max[1]}"
+            frames[key] = tf(original_frame)
+
+    save_file(frames, output_dir / "single_transforms.safetensors")
+
+
+def main():
+    dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
+    output_dir = Path(ARTIFACT_DIR)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    original_frame = dataset[0][dataset.camera_keys[0]]
+
+    save_single_transforms(original_frame, output_dir)
+    save_default_config_transform(original_frame, output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
new file mode 100644
index 00000000..ba6d972f
--- /dev/null
+++ b/tests/test_image_transforms.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from safetensors.torch import load_file
+from torchvision.transforms import v2
+from torchvision.transforms.v2 import functional as F  # noqa: N812
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
+from lerobot.common.utils.utils import init_hydra_config, seeded_context
+from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
+
+ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
+DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
+
+
+def load_png_to_tensor(path: Path):
+    return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
+
+
+@pytest.fixture
+def img():
+    dataset = LeRobotDataset(DATASET_REPO_ID)
+    return dataset[0][dataset.camera_keys[0]]
+
+
+@pytest.fixture
+def img_random():
+    return torch.rand(3, 480, 640)
+
+
+@pytest.fixture
+def color_jitters():
+    return [
+        v2.ColorJitter(brightness=0.5),
+        v2.ColorJitter(contrast=0.5),
+        v2.ColorJitter(saturation=0.5),
+    ]
+
+
+@pytest.fixture
+def single_transforms():
+    return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
+
+
+@pytest.fixture
+def default_transforms():
+    return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
+
+
+def test_get_image_transforms_no_transform(img):
+    tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
+    torch.testing.assert_close(tf_actual(img), img)
+
+
+@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
+def test_get_image_transforms_brightness(img, min_max):
+    tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
+    tf_expected = v2.ColorJitter(brightness=min_max)
+    torch.testing.assert_close(tf_actual(img), tf_expected(img))
+
+
+@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
+def test_get_image_transforms_contrast(img, min_max):
+    tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
+    tf_expected = v2.ColorJitter(contrast=min_max)
+    torch.testing.assert_close(tf_actual(img), tf_expected(img))
+
+
+@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
+def test_get_image_transforms_saturation(img, min_max):
+    tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
+    tf_expected = v2.ColorJitter(saturation=min_max)
+    torch.testing.assert_close(tf_actual(img), tf_expected(img))
+
+
+@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
+def test_get_image_transforms_hue(img, min_max):
+    tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
+    tf_expected = v2.ColorJitter(hue=min_max)
+    torch.testing.assert_close(tf_actual(img), tf_expected(img))
+
+
+@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
+def test_get_image_transforms_sharpness(img, min_max):
+    tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
+    tf_expected = SharpnessJitter(sharpness=min_max)
+    torch.testing.assert_close(tf_actual(img), tf_expected(img))
+
+
+def test_get_image_transforms_max_num_transforms(img):
+    tf_actual = get_image_transforms(
+        brightness_min_max=(0.5, 0.5),
+        contrast_min_max=(0.5, 0.5),
+        saturation_min_max=(0.5, 0.5),
+        hue_min_max=(0.5, 0.5),
+        sharpness_min_max=(0.5, 0.5),
+        random_order=False,
+    )
+    tf_expected = v2.Compose(
+        [
+            v2.ColorJitter(brightness=(0.5, 0.5)),
+            v2.ColorJitter(contrast=(0.5, 0.5)),
+            v2.ColorJitter(saturation=(0.5, 0.5)),
+            v2.ColorJitter(hue=(0.5, 0.5)),
+            SharpnessJitter(sharpness=(0.5, 0.5)),
+        ]
+    )
+    torch.testing.assert_close(tf_actual(img), tf_expected(img))
+
+
+@require_x86_64_kernel
+def test_get_image_transforms_random_order(img):
+    out_imgs = []
+    tf = get_image_transforms(
+        brightness_min_max=(0.5, 0.5),
+        contrast_min_max=(0.5, 0.5),
+        saturation_min_max=(0.5, 0.5),
+        hue_min_max=(0.5, 0.5),
+        sharpness_min_max=(0.5, 0.5),
+        random_order=True,
+    )
+    with seeded_context(1337):
+        for _ in range(10):
+            out_imgs.append(tf(img))
+
+    for i in range(1, len(out_imgs)):
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(out_imgs[0], out_imgs[i])
+
+
+@pytest.mark.parametrize(
+    "transform, min_max_values",
+    [
+        ("brightness", [(0.5, 0.5), (2.0, 2.0)]),
+        ("contrast", [(0.5, 0.5), (2.0, 2.0)]),
+        ("saturation", [(0.5, 0.5), (2.0, 2.0)]),
+        ("hue", [(-0.25, -0.25), (0.25, 0.25)]),
+        ("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
+    ],
+)
+def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
+    for min_max in min_max_values:
+        kwargs = {
+            f"{transform}_weight": 1.0,
+            f"{transform}_min_max": min_max,
+        }
+        tf = get_image_transforms(**kwargs)
+        actual = tf(img)
+        key = f"{transform}_{min_max[0]}_{min_max[1]}"
+        expected = single_transforms[key]
+        torch.testing.assert_close(actual, expected)
+
+
+@require_x86_64_kernel
+def test_backward_compatibility_default_config(img, default_transforms):
+    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
+    cfg_tf = cfg.training.image_transforms
+    default_tf = get_image_transforms(
+        brightness_weight=cfg_tf.brightness.weight,
+        brightness_min_max=cfg_tf.brightness.min_max,
+        contrast_weight=cfg_tf.contrast.weight,
+        contrast_min_max=cfg_tf.contrast.min_max,
+        saturation_weight=cfg_tf.saturation.weight,
+        saturation_min_max=cfg_tf.saturation.min_max,
+        hue_weight=cfg_tf.hue.weight,
+        hue_min_max=cfg_tf.hue.min_max,
+        sharpness_weight=cfg_tf.sharpness.weight,
+        sharpness_min_max=cfg_tf.sharpness.min_max,
+        max_num_transforms=cfg_tf.max_num_transforms,
+        random_order=cfg_tf.random_order,
+    )
+
+    with seeded_context(1337):
+        actual = default_tf(img)
+
+    expected = default_transforms["default"]
+
+    torch.testing.assert_close(actual, expected)
+
+
+@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
+def test_random_subset_apply_single_choice(p, img):
+    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
+    random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
+    actual = random_choice(img)
+
+    p_horz, _ = p
+    if p_horz:
+        torch.testing.assert_close(actual, F.horizontal_flip(img))
+    else:
+        torch.testing.assert_close(actual, F.vertical_flip(img))
+
+
+def test_random_subset_apply_random_order(img):
+    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
+    random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
+    # We can't really check whether the transforms are actually applied in random order. However,
+    # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
+    # applies them in random order, we can use a fixed order to compute the expected value.
+    actual = random_order(img)
+    expected = v2.Compose(flips)(img)
+    torch.testing.assert_close(actual, expected)
+
+
+def test_random_subset_apply_valid_transforms(color_jitters, img):
+    transform = RandomSubsetApply(color_jitters)
+    output = transform(img)
+    assert output.shape == img.shape
+
+
+def test_random_subset_apply_probability_length_mismatch(color_jitters):
+    with pytest.raises(ValueError):
+        RandomSubsetApply(color_jitters, p=[0.5, 0.5])
+
+
+@pytest.mark.parametrize("n_subset", [0, 5])
+def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
+    with pytest.raises(ValueError):
+        RandomSubsetApply(color_jitters, n_subset=n_subset)
+
+
+def test_sharpness_jitter_valid_range_tuple(img):
+    tf = SharpnessJitter((0.1, 2.0))
+    output = tf(img)
+    assert output.shape == img.shape
+
+
+def test_sharpness_jitter_valid_range_float(img):
+    tf = SharpnessJitter(0.5)
+    output = tf(img)
+    assert output.shape == img.shape
+
+
+def test_sharpness_jitter_invalid_range_min_negative():
+    with pytest.raises(ValueError):
+        SharpnessJitter((-0.1, 2.0))
+
+
+def test_sharpness_jitter_invalid_range_max_smaller():
+    with pytest.raises(ValueError):
+        SharpnessJitter((2.0, 0.1))