diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index a7cb6374..1aa0bc2d 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -165,15 +165,6 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D version_base="1.2", ) cfg = hydra.compose(Path(config_path).stem, overrides) - if cfg.eval.batch_size > cfg.eval.n_episodes: - raise ValueError( - "The eval batch size is greater than the number of eval episodes " - f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} " - f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. " - "This might significantly slow down evaluation. To fix this, you should update your command " - f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), " - f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)." - ) return cfg diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 980c373c..482af786 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -454,6 +454,16 @@ def main( else: hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides) + if hydra_cfg.eval.batch_size > hydra_cfg.eval.n_episodes: + raise ValueError( + "The eval batch size is greater than the number of eval episodes " + f"({hydra_cfg.eval.batch_size} > {hydra_cfg.eval.n_episodes}). As a result, {hydra_cfg.eval.batch_size} " + f"eval environments will be instantiated, but only {hydra_cfg.eval.n_episodes} will be used. " + "This might significantly slow down evaluation. To fix this, you should update your command " + f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={hydra_cfg.eval.batch_size}`), " + f"or lower the batch size (e.g. `eval.batch_size={hydra_cfg.eval.n_episodes}`)." + ) + if out_dir is None: out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 2fa7ae80..45807503 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -288,6 +288,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No "you meant to resume training, please use `resume=true` in your command or yaml configuration." ) + if cfg.eval.batch_size > cfg.eval.n_episodes: + raise ValueError( + "The eval batch size is greater than the number of eval episodes " + f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} " + f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. " + "This might significantly slow down evaluation. To fix this, you should update your command " + f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), " + f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)." + ) + # log metrics to terminal and wandb logger = Logger(cfg, out_dir, wandb_job_name=job_name) diff --git a/tests/test_utils.py b/tests/test_utils.py index d4a8e34a..e5ba2267 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ import random from typing import Callable +from uuid import uuid4 import numpy as np import pytest @@ -13,6 +14,7 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.utils.utils import ( get_global_random_state, + init_hydra_config, seeded_context, set_global_random_state, set_global_seed, @@ -83,3 +85,10 @@ def test_reset_episode_index(): correct_episode_index = [0, 0, 1, 2, 2, 2] dataset = reset_episode_index(dataset) assert dataset["episode_index"] == correct_episode_index + + +def test_init_hydra_config_empty(): + test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml" + with open(test_file, "w") as f: + f.write("\n") + init_hydra_config(test_file)