FIx make_dataset to match transforms config (#264)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
parent
ff8f6aa6cd
commit
c38f535c9f
|
@ -10,6 +10,7 @@ on:
|
||||||
- "examples/**"
|
- "examples/**"
|
||||||
- ".github/**"
|
- ".github/**"
|
||||||
- "poetry.lock"
|
- "poetry.lock"
|
||||||
|
- "Makefile"
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
|
@ -19,6 +20,7 @@ on:
|
||||||
- "examples/**"
|
- "examples/**"
|
||||||
- ".github/**"
|
- ".github/**"
|
||||||
- "poetry.lock"
|
- "poetry.lock"
|
||||||
|
- "Makefile"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pytest:
|
pytest:
|
||||||
|
|
5
Makefile
5
Makefile
|
@ -46,6 +46,7 @@ test-act-ete-train:
|
||||||
policy.n_action_steps=20 \
|
policy.n_action_steps=20 \
|
||||||
policy.chunk_size=20 \
|
policy.chunk_size=20 \
|
||||||
training.batch_size=2 \
|
training.batch_size=2 \
|
||||||
|
training.image_transforms.enable=true \
|
||||||
hydra.run.dir=tests/outputs/act/
|
hydra.run.dir=tests/outputs/act/
|
||||||
|
|
||||||
test-act-ete-eval:
|
test-act-ete-eval:
|
||||||
|
@ -73,6 +74,7 @@ test-act-ete-train-amp:
|
||||||
policy.chunk_size=20 \
|
policy.chunk_size=20 \
|
||||||
training.batch_size=2 \
|
training.batch_size=2 \
|
||||||
hydra.run.dir=tests/outputs/act_amp/ \
|
hydra.run.dir=tests/outputs/act_amp/ \
|
||||||
|
training.image_transforms.enable=true \
|
||||||
use_amp=true
|
use_amp=true
|
||||||
|
|
||||||
test-act-ete-eval-amp:
|
test-act-ete-eval-amp:
|
||||||
|
@ -100,6 +102,7 @@ test-diffusion-ete-train:
|
||||||
training.save_checkpoint=true \
|
training.save_checkpoint=true \
|
||||||
training.save_freq=2 \
|
training.save_freq=2 \
|
||||||
training.batch_size=2 \
|
training.batch_size=2 \
|
||||||
|
training.image_transforms.enable=true \
|
||||||
hydra.run.dir=tests/outputs/diffusion/
|
hydra.run.dir=tests/outputs/diffusion/
|
||||||
|
|
||||||
test-diffusion-ete-eval:
|
test-diffusion-ete-eval:
|
||||||
|
@ -127,6 +130,7 @@ test-tdmpc-ete-train:
|
||||||
training.save_checkpoint=true \
|
training.save_checkpoint=true \
|
||||||
training.save_freq=2 \
|
training.save_freq=2 \
|
||||||
training.batch_size=2 \
|
training.batch_size=2 \
|
||||||
|
training.image_transforms.enable=true \
|
||||||
hydra.run.dir=tests/outputs/tdmpc/
|
hydra.run.dir=tests/outputs/tdmpc/
|
||||||
|
|
||||||
test-tdmpc-ete-eval:
|
test-tdmpc-ete-eval:
|
||||||
|
@ -159,5 +163,6 @@ test-act-pusht-tutorial:
|
||||||
training.save_model=true \
|
training.save_model=true \
|
||||||
training.save_freq=2 \
|
training.save_freq=2 \
|
||||||
training.batch_size=2 \
|
training.batch_size=2 \
|
||||||
|
training.image_transforms.enable=true \
|
||||||
hydra.run.dir=tests/outputs/act_pusht/
|
hydra.run.dir=tests/outputs/act_pusht/
|
||||||
rm lerobot/configs/policy/created_by_Makefile.yaml
|
rm lerobot/configs/policy/created_by_Makefile.yaml
|
||||||
|
|
|
@ -74,19 +74,20 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||||
|
|
||||||
image_transforms = None
|
image_transforms = None
|
||||||
if cfg.training.image_transforms.enable:
|
if cfg.training.image_transforms.enable:
|
||||||
|
cfg_tf = cfg.training.image_transforms
|
||||||
image_transforms = get_image_transforms(
|
image_transforms = get_image_transforms(
|
||||||
brightness_weight=cfg.brightness.weight,
|
brightness_weight=cfg_tf.brightness.weight,
|
||||||
brightness_min_max=cfg.brightness.min_max,
|
brightness_min_max=cfg_tf.brightness.min_max,
|
||||||
contrast_weight=cfg.contrast.weight,
|
contrast_weight=cfg_tf.contrast.weight,
|
||||||
contrast_min_max=cfg.contrast.min_max,
|
contrast_min_max=cfg_tf.contrast.min_max,
|
||||||
saturation_weight=cfg.saturation.weight,
|
saturation_weight=cfg_tf.saturation.weight,
|
||||||
saturation_min_max=cfg.saturation.min_max,
|
saturation_min_max=cfg_tf.saturation.min_max,
|
||||||
hue_weight=cfg.hue.weight,
|
hue_weight=cfg_tf.hue.weight,
|
||||||
hue_min_max=cfg.hue.min_max,
|
hue_min_max=cfg_tf.hue.min_max,
|
||||||
sharpness_weight=cfg.sharpness.weight,
|
sharpness_weight=cfg_tf.sharpness.weight,
|
||||||
sharpness_min_max=cfg.sharpness.min_max,
|
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||||
max_num_transforms=cfg.max_num_transforms,
|
max_num_transforms=cfg_tf.max_num_transforms,
|
||||||
random_order=cfg.random_order,
|
random_order=cfg_tf.random_order,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(cfg.dataset_repo_id, str):
|
if isinstance(cfg.dataset_repo_id, str):
|
||||||
|
|
Loading…
Reference in New Issue