From 9c7649f140fb437471f915423439ba7096da422b Mon Sep 17 00:00:00 2001
From: Alexander Soare <alexander.soare159@gmail.com>
Date: Fri, 23 Aug 2024 12:27:08 +0100
Subject: [PATCH] Make sure `init_hydra_config` does not require any keys
 (#376)

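Previously, `init_hydra_config` validated `eval.batch_size` against
`eval.n_episodes`, so composing a config that lacked the `eval` section
raised an error. Move that check into `lerobot/scripts/eval.py` and
`lerobot/scripts/train.py`, which actually consume those keys, and add
a test that `init_hydra_config` can compose an empty config file.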
---
 lerobot/common/utils/utils.py |  9 ---------
 lerobot/scripts/eval.py       | 10 ++++++++++
 lerobot/scripts/train.py      | 10 ++++++++++
 tests/test_utils.py           |  9 +++++++++
 4 files changed, 29 insertions(+), 9 deletions(-)

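A minimal usage sketch (the config path and override values below are
assumptions for illustration, not taken from this patch): composition itself no
longer enforces the eval batch-size constraint, which is now checked by the
eval/train entry points.

    from lerobot.common.utils.utils import init_hydra_config

    # Composes fine even though eval.batch_size > eval.n_episodes: the guard has
    # moved out of init_hydra_config and into eval.py / train.py, where the
    # ValueError is raised once the entry point inspects cfg.eval.
    cfg = init_hydra_config("lerobot/configs/default.yaml", ["eval.batch_size=10", "eval.n_episodes=1"])

    # A config with no `eval` keys (or an empty file) also composes now, which is
    # what the new test in tests/test_utils.py exercises.
    cfg = init_hydra_config("/tmp/empty_config.yaml")  # hypothetical empty YAML written beforehand
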
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index a7cb6374..1aa0bc2d 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -165,15 +165,6 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
         version_base="1.2",
     )
     cfg = hydra.compose(Path(config_path).stem, overrides)
-    if cfg.eval.batch_size > cfg.eval.n_episodes:
-        raise ValueError(
-            "The eval batch size is greater than the number of eval episodes "
-            f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
-            f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
-            "This might significantly slow down evaluation. To fix this, you should update your command "
-            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
-            f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
-        )
     return cfg
 
 
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 980c373c..482af786 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -454,6 +454,16 @@ def main(
     else:
         hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
 
+    if hydra_cfg.eval.batch_size > hydra_cfg.eval.n_episodes:
+        raise ValueError(
+            "The eval batch size is greater than the number of eval episodes "
+            f"({hydra_cfg.eval.batch_size} > {hydra_cfg.eval.n_episodes}). As a result, {hydra_cfg.eval.batch_size} "
+            f"eval environments will be instantiated, but only {hydra_cfg.eval.n_episodes} will be used. "
+            "This might significantly slow down evaluation. To fix this, you should update your command "
+            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={hydra_cfg.eval.batch_size}`), "
+            f"or lower the batch size (e.g. `eval.batch_size={hydra_cfg.eval.n_episodes}`)."
+        )
+
     if out_dir is None:
         out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
 
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 2fa7ae80..45807503 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -288,6 +288,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             "you meant to resume training, please use `resume=true` in your command or yaml configuration."
         )
 
+    if cfg.eval.batch_size > cfg.eval.n_episodes:
+        raise ValueError(
+            "The eval batch size is greater than the number of eval episodes "
+            f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
+            f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
+            "This might significantly slow down evaluation. To fix this, you should update your command "
+            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
+            f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
+        )
+
     # log metrics to terminal and wandb
     logger = Logger(cfg, out_dir, wandb_job_name=job_name)
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d4a8e34a..e5ba2267 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,6 @@
 import random
 from typing import Callable
+from uuid import uuid4
 
 import numpy as np
 import pytest
@@ -13,6 +14,7 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.utils.utils import (
     get_global_random_state,
+    init_hydra_config,
     seeded_context,
     set_global_random_state,
     set_global_seed,
@@ -83,3 +85,10 @@ def test_reset_episode_index():
     correct_episode_index = [0, 0, 1, 2, 2, 2]
     dataset = reset_episode_index(dataset)
     assert dataset["episode_index"] == correct_episode_index
+
+
+def test_init_hydra_config_empty():
+    test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
+    with open(test_file, "w") as f:
+        f.write("\n")
+    init_hydra_config(test_file)