Updated config to match transforms

This commit is contained in:
Marina Barannikov 2024-06-04 11:09:23 +00:00
commit 66629a956d
3 changed files with 66 additions and 24 deletions

View File

@ -17,9 +17,9 @@ import logging
import torch import torch
from omegaconf import ListConfig, OmegaConf from omegaconf import ListConfig, OmegaConf
from torchvision.transforms import v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import make_transforms
def resolve_delta_timestamps(cfg): def resolve_delta_timestamps(cfg):
@ -72,23 +72,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg) resolve_delta_timestamps(cfg)
if cfg.image_transform.enable: transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
transform = v2.Compose(
[
v2.ColorJitter(
brightness=cfg.image_transform.colorjitter_factor,
contrast=cfg.image_transform.colorjitter_factor,
),
v2.RandomAdjustSharpness(
cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p
),
# Using RandomAdjustSharpness with parameter < 1 adds blur to the image
v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
v2.ToDtype(torch.float32, scale=True),
]
)
else:
transform = None
if isinstance(cfg.dataset_repo_id, str): if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset( dataset = LeRobotDataset(

View File

@ -0,0 +1,57 @@
from typing import Any, Callable, Sequence
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
class RandomSubsetApply(Transform):
"""
Apply a random subset of N transformations from a list of transformations in a random order.
Args:
transforms (sequence or torch.nn.Module): list of transformations
N (int): number of transformations to apply
"""
def __init__(self, transforms: Sequence[Callable], n_subset: int) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if not (0 <= n_subset <= len(transforms)):
raise ValueError(f"N should be in the interval [0, {len(transforms)}]")
self.transforms = transforms
self.N = n_subset
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
# Randomly pick N transforms
selected_transforms = torch.randperm(len(self.transforms))[: self.N]
# Apply selected transforms in random order
for idx in selected_transforms:
transform = self.transforms[idx]
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
format_string = [f"N={self.N}"]
for t in self.transforms:
format_string.append(f" {t}")
return "\n".join(format_string)
def make_transforms(cfg):
image_transforms = []
if 'jit' in cfg.list_transforms:
image_transforms.append(v2.ColorJitter(brightness=cfg.colorjitter_range, contrast=cfg.colorjitter_range))
if 'sharpness' in cfg.list_transforms:
image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
if 'blur' in cfg.list_transforms:
image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
return v2.Compose(RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True))

View File

@ -60,13 +60,14 @@ wandb:
image_transform: image_transform:
enable: false enable: false
colorjittor_factor: 0.5 colorjittor_range: (0, 1)
# Noise sampled randomly from 0 to colorjittor_factor # Range from which to sample colorjittor factor
sharpness_factor: 2 sharpness_factor: 3
# Should be more than 1, setting parameter to 1 does not change the image # Should be more than 1, setting parameter to 1 does not change the image
sharpness_p: 0.5 sharpness_p: 0.5
# Probability that Sharpness is applied blur_factor: 0.2
blur_factor: 0.5
# Should be less than 1, setting parameter to 1 does not change the image # Should be less than 1, setting parameter to 1 does not change the image
blur_p: 0.5 blur_p: 0.5
# Probability that Blur is applied n_subset: 3
# Maximum number of transforms to apply
list: ["colorjittor", "sharpness", "blur"]