Implemented data augmentation with LeRobot class

This commit is contained in:
marina.barannikov@huggingface.co 2024-05-31 14:16:38 +00:00
parent 265b0ec44d
commit 65e46a49e1
3 changed files with 20 additions and 2 deletions

View File

@ -16,6 +16,7 @@
import logging
import torch
from torchvision.transforms import v2
from omegaconf import OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -47,12 +48,19 @@ def make_dataset(
resolve_delta_timestamps(cfg)
# TODO(rcadene): add data augmentations
if cfg.image_transform.enable:
transform = v2.Compose([v2.ColorJitter(brightness=cfg.image_transform.colorjitter_factor, contrast=cfg.image_transform.colorjitter_factor),
v2.RandomAdjustSharpness(cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p), v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
v2.ToDtype(torch.float32, scale=True),
])
else:
transform = None
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform
)
if cfg.get("override_dataset_stats"):

View File

@ -148,7 +148,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
if self.transform is not None:
item = self.transform(item)
for cam in self.camera_keys:
item[cam]=self.transform(item[cam])
return item

View File

@ -53,3 +53,12 @@ wandb:
disable_artifact: false
project: lerobot
notes: ""
image_transform:
enable: false
colorjittor_factor: 0.5
colorjittor_p: 0.5
sharpness_factor: 2
sharpness_p: 0.5
blur_factor: 0.5
blur_p: 0.5