Nest image_transforms config under training

This commit is contained in:
Simon Alibert 2024-06-11 07:34:01 +00:00
parent bde12f6a59
commit 7a097e9e9f
5 changed files with 39 additions and 40 deletions

View File

@ -73,7 +73,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg)
image_transforms = None
if cfg.image_transforms.enable:
if cfg.training.image_transforms.enable:
image_transforms = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,

View File

@ -43,6 +43,39 @@ training:
save_checkpoint: true
num_workers: 4
batch_size: ???
image_transforms:
# These transforms are all using standard torchvision.transforms.v2
# You can find out how these transformations affect images here:
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
# We use a custom RandomSubsetApply container to sample them.
# For each transform, the following parameters are available:
# weight: This represents the multinomial probability (with no replacement)
# used for sampling the transform. If the sum of the weights is not 1,
# they will be normalized.
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
# (following uniform distribution) when it's applied.
enable: false
# This is the number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [0, number of available transforms].
max_num_transforms: 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: false
brightness:
weight: 1
min_max: [0.8, 1.2]
contrast:
weight: 1
min_max: [0.8, 1.2]
saturation:
weight: 1
min_max: [0.5, 1.5]
hue:
weight: 1
min_max: [-0.05, 0.05]
sharpness:
weight: 1
min_max: [0.8, 1.2]
eval:
n_episodes: 1
@ -57,37 +90,3 @@ wandb:
disable_artifact: false
project: lerobot
notes: ""
image_transforms:
# These transforms are all using standard torchvision.transforms.v2
# You can find out how these transformations affect images here:
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
# We use a custom RandomSubsetApply container to sample them.
# For each transform, the following parameters are available:
# weight: This represents the multinomial probability (with no replacement)
# used for sampling the transform. If the sum of the weights is not 1,
# they will be normalized.
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
# (following uniform distribution) when it's applied.
enable: false
# This is the number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [0, number of available transforms].
max_num_transforms: 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: false
brightness:
weight: 1
min_max: [0.8, 1.2]
contrast:
weight: 1
min_max: [0.8, 1.2]
saturation:
weight: 1
min_max: [0.5, 1.5]
hue:
weight: 1
min_max: [-0.05, 0.05]
sharpness:
weight: 1
min_max: [0.8, 1.2]

View File

@ -26,11 +26,11 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
for transform_name in transforms:
for t in transforms:
if t == transform_name:
cfg.image_transforms[t].weight = 1
cfg.training.image_transforms[t].weight = 1
else:
cfg.image_transforms[t].weight = 0
cfg.training.image_transforms[t].weight = 0
transform = make_image_transforms(cfg.image_transforms)
transform = make_image_transforms(cfg.training.image_transforms)
img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)

View File

@ -12,7 +12,7 @@ from tests.utils import DEFAULT_CONFIG_PATH
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.image_transforms
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,

View File

@ -158,7 +158,7 @@ def test_backward_compatibility_torchvision(transform, img, single_transforms):
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.image_transforms
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,