FIx make_dataset to match transforms config (#264)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Marina Barannikov 2024-06-12 19:45:42 +02:00 committed by GitHub
parent ff8f6aa6cd
commit c38f535c9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 13 deletions

View File

@ -10,6 +10,7 @@ on:
- "examples/**" - "examples/**"
- ".github/**" - ".github/**"
- "poetry.lock" - "poetry.lock"
- "Makefile"
push: push:
branches: branches:
- main - main
@ -19,6 +20,7 @@ on:
- "examples/**" - "examples/**"
- ".github/**" - ".github/**"
- "poetry.lock" - "poetry.lock"
- "Makefile"
jobs: jobs:
pytest: pytest:

View File

@ -5,7 +5,7 @@ PYTHON_PATH := $(shell which python)
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python # If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
POETRY_CHECK := $(shell command -v poetry) POETRY_CHECK := $(shell command -v poetry)
ifneq ($(POETRY_CHECK),) ifneq ($(POETRY_CHECK),)
PYTHON_PATH := $(shell poetry run which python) PYTHON_PATH := $(shell poetry run which python)
endif endif
export PATH := $(dir $(PYTHON_PATH)):$(PATH) export PATH := $(dir $(PYTHON_PATH)):$(PATH)
@ -46,6 +46,7 @@ test-act-ete-train:
policy.n_action_steps=20 \ policy.n_action_steps=20 \
policy.chunk_size=20 \ policy.chunk_size=20 \
training.batch_size=2 \ training.batch_size=2 \
training.image_transforms.enable=true \
hydra.run.dir=tests/outputs/act/ hydra.run.dir=tests/outputs/act/
test-act-ete-eval: test-act-ete-eval:
@ -73,6 +74,7 @@ test-act-ete-train-amp:
policy.chunk_size=20 \ policy.chunk_size=20 \
training.batch_size=2 \ training.batch_size=2 \
hydra.run.dir=tests/outputs/act_amp/ \ hydra.run.dir=tests/outputs/act_amp/ \
training.image_transforms.enable=true \
use_amp=true use_amp=true
test-act-ete-eval-amp: test-act-ete-eval-amp:
@ -100,6 +102,7 @@ test-diffusion-ete-train:
training.save_checkpoint=true \ training.save_checkpoint=true \
training.save_freq=2 \ training.save_freq=2 \
training.batch_size=2 \ training.batch_size=2 \
training.image_transforms.enable=true \
hydra.run.dir=tests/outputs/diffusion/ hydra.run.dir=tests/outputs/diffusion/
test-diffusion-ete-eval: test-diffusion-ete-eval:
@ -127,6 +130,7 @@ test-tdmpc-ete-train:
training.save_checkpoint=true \ training.save_checkpoint=true \
training.save_freq=2 \ training.save_freq=2 \
training.batch_size=2 \ training.batch_size=2 \
training.image_transforms.enable=true \
hydra.run.dir=tests/outputs/tdmpc/ hydra.run.dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval: test-tdmpc-ete-eval:
@ -159,5 +163,6 @@ test-act-pusht-tutorial:
training.save_model=true \ training.save_model=true \
training.save_freq=2 \ training.save_freq=2 \
training.batch_size=2 \ training.batch_size=2 \
training.image_transforms.enable=true \
hydra.run.dir=tests/outputs/act_pusht/ hydra.run.dir=tests/outputs/act_pusht/
rm lerobot/configs/policy/created_by_Makefile.yaml rm lerobot/configs/policy/created_by_Makefile.yaml

View File

@ -74,19 +74,20 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
image_transforms = None image_transforms = None
if cfg.training.image_transforms.enable: if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
image_transforms = get_image_transforms( image_transforms = get_image_transforms(
brightness_weight=cfg.brightness.weight, brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg.brightness.min_max, brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg.contrast.weight, contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg.contrast.min_max, contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg.saturation.weight, saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg.saturation.min_max, saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg.hue.weight, hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg.hue.min_max, hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg.sharpness.weight, sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max, sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms, max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg.random_order, random_order=cfg_tf.random_order,
) )
if isinstance(cfg.dataset_repo_id, str): if isinstance(cfg.dataset_repo_id, str):