Make sure `init_hydra_config` does not require any keys (#376)
This commit is contained in:
parent
a2592a5563
commit
9c7649f140
|
@ -165,15 +165,6 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||||
version_base="1.2",
|
version_base="1.2",
|
||||||
)
|
)
|
||||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||||
if cfg.eval.batch_size > cfg.eval.n_episodes:
|
|
||||||
raise ValueError(
|
|
||||||
"The eval batch size is greater than the number of eval episodes "
|
|
||||||
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
|
|
||||||
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
|
|
||||||
"This might significantly slow down evaluation. To fix this, you should update your command "
|
|
||||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
|
|
||||||
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
|
|
||||||
)
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -454,6 +454,16 @@ def main(
|
||||||
else:
|
else:
|
||||||
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
||||||
|
|
||||||
|
if hydra_cfg.eval.batch_size > hydra_cfg.eval.n_episodes:
|
||||||
|
raise ValueError(
|
||||||
|
"The eval batch size is greater than the number of eval episodes "
|
||||||
|
f"({hydra_cfg.eval.batch_size} > {hydra_cfg.eval.n_episodes}). As a result, {hydra_cfg.eval.batch_size} "
|
||||||
|
f"eval environments will be instantiated, but only {hydra_cfg.eval.n_episodes} will be used. "
|
||||||
|
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||||
|
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={hydra_cfg.eval.batch_size}`), "
|
||||||
|
f"or lower the batch size (e.g. `eval.batch_size={hydra_cfg.eval.n_episodes}`)."
|
||||||
|
)
|
||||||
|
|
||||||
if out_dir is None:
|
if out_dir is None:
|
||||||
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||||
|
|
||||||
|
|
|
@ -288,6 +288,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
|
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.eval.batch_size > cfg.eval.n_episodes:
|
||||||
|
raise ValueError(
|
||||||
|
"The eval batch size is greater than the number of eval episodes "
|
||||||
|
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
|
||||||
|
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
|
||||||
|
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||||
|
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
|
||||||
|
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
|
||||||
|
)
|
||||||
|
|
||||||
# log metrics to terminal and wandb
|
# log metrics to terminal and wandb
|
||||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import random
|
import random
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -13,6 +14,7 @@ from lerobot.common.datasets.utils import (
|
||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
get_global_random_state,
|
get_global_random_state,
|
||||||
|
init_hydra_config,
|
||||||
seeded_context,
|
seeded_context,
|
||||||
set_global_random_state,
|
set_global_random_state,
|
||||||
set_global_seed,
|
set_global_seed,
|
||||||
|
@ -83,3 +85,10 @@ def test_reset_episode_index():
|
||||||
correct_episode_index = [0, 0, 1, 2, 2, 2]
|
correct_episode_index = [0, 0, 1, 2, 2, 2]
|
||||||
dataset = reset_episode_index(dataset)
|
dataset = reset_episode_index(dataset)
|
||||||
assert dataset["episode_index"] == correct_episode_index
|
assert dataset["episode_index"] == correct_episode_index
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_hydra_config_empty():
|
||||||
|
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.write("\n")
|
||||||
|
init_hydra_config(test_file)
|
||||||
|
|
Loading…
Reference in New Issue