Implemented data augmentation with LeRobot class
This commit is contained in:
parent
265b0ec44d
commit
65e46a49e1
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
@ -47,12 +48,19 @@ def make_dataset(
|
|||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
if cfg.image_transform.enable:
|
||||
transform = v2.Compose([v2.ColorJitter(brightness=cfg.image_transform.colorjitter_factor, contrast=cfg.image_transform.colorjitter_factor),
|
||||
v2.RandomAdjustSharpness(cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p), v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
else:
|
||||
transform = None
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
transform=transform
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
|
|
@ -148,7 +148,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
for cam in self.camera_keys:
|
||||
item[cam]=self.transform(item[cam])
|
||||
|
||||
return item
|
||||
|
||||
|
|
|
@ -53,3 +53,12 @@ wandb:
|
|||
disable_artifact: false
|
||||
project: lerobot
|
||||
notes: ""
|
||||
|
||||
image_transform:
|
||||
enable: false
|
||||
colorjittor_factor: 0.5
|
||||
colorjittor_p: 0.5
|
||||
sharpness_factor: 2
|
||||
sharpness_p: 0.5
|
||||
blur_factor: 0.5
|
||||
blur_p: 0.5
|
Loading…
Reference in New Issue