From 65e46a49e17bff4cc40b736264c25ee8788735ff Mon Sep 17 00:00:00 2001 From: "marina.barannikov@huggingface.co" Date: Fri, 31 May 2024 14:16:38 +0000 Subject: [PATCH] Implemented data augmentation with LeRobot class --- lerobot/common/datasets/factory.py | 10 +++++++++- lerobot/common/datasets/lerobot_dataset.py | 3 ++- lerobot/configs/default.yaml | 9 +++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 7bdc2ca9..0bd09ac4 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -16,6 +16,7 @@ import logging import torch +from torchvision.transforms import v2 from omegaconf import OmegaConf from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -47,12 +48,19 @@ def make_dataset( resolve_delta_timestamps(cfg) - # TODO(rcadene): add data augmentations + if cfg.image_transform.enable: + transform = v2.Compose([v2.ColorJitter(brightness=cfg.image_transform.colorjitter_factor, contrast=cfg.image_transform.colorjitter_factor), + v2.RandomAdjustSharpness(cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p), v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p), + v2.ToDtype(torch.float32, scale=True), + ]) + else: + transform = None dataset = LeRobotDataset( cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps"), + transform=transform ) if cfg.get("override_dataset_stats"): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 057e4770..d51e1fba 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -148,7 +148,8 @@ class LeRobotDataset(torch.utils.data.Dataset): ) if self.transform is not None: - item = self.transform(item) + for cam in self.camera_keys: + item[cam]=self.transform(item[cam]) return item diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index f2238769..6ceca699 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -53,3 +53,12 @@ wandb: disable_artifact: false project: lerobot notes: "" + +image_transform: + enable: false + colorjittor_factor: 0.5 + colorjittor_p: 0.5 + sharpness_factor: 2 + sharpness_p: 0.5 + blur_factor: 0.5 + blur_p: 0.5 \ No newline at end of file