Implemented data augmentation with LeRobot class
This commit is contained in:
parent
265b0ec44d
commit
65e46a49e1
|
@ -16,6 +16,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torchvision.transforms import v2
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
@ -47,12 +48,19 @@ def make_dataset(
|
||||||
|
|
||||||
resolve_delta_timestamps(cfg)
|
resolve_delta_timestamps(cfg)
|
||||||
|
|
||||||
# TODO(rcadene): add data augmentations
|
if cfg.image_transform.enable:
|
||||||
|
transform = v2.Compose([v2.ColorJitter(brightness=cfg.image_transform.colorjitter_factor, contrast=cfg.image_transform.colorjitter_factor),
|
||||||
|
v2.RandomAdjustSharpness(cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p), v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
transform = None
|
||||||
|
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset_repo_id,
|
cfg.dataset_repo_id,
|
||||||
split=split,
|
split=split,
|
||||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||||
|
transform=transform
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.get("override_dataset_stats"):
|
if cfg.get("override_dataset_stats"):
|
||||||
|
|
|
@ -148,7 +148,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
item = self.transform(item)
|
for cam in self.camera_keys:
|
||||||
|
item[cam]=self.transform(item[cam])
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
|
@ -53,3 +53,12 @@ wandb:
|
||||||
disable_artifact: false
|
disable_artifact: false
|
||||||
project: lerobot
|
project: lerobot
|
||||||
notes: ""
|
notes: ""
|
||||||
|
|
||||||
|
image_transform:
|
||||||
|
enable: false
|
||||||
|
colorjittor_factor: 0.5
|
||||||
|
colorjittor_p: 0.5
|
||||||
|
sharpness_factor: 2
|
||||||
|
sharpness_p: 0.5
|
||||||
|
blur_factor: 0.5
|
||||||
|
blur_p: 0.5
|
Loading…
Reference in New Issue