32 lines
785 B
Python
32 lines
785 B
Python
|
from pprint import pprint
|
||
|
|
||
|
from hydra import compose, initialize
|
||
|
from hydra.core.global_hydra import GlobalHydra
|
||
|
from omegaconf import OmegaConf
|
||
|
from omegaconf.dictconfig import DictConfig
|
||
|
|
||
|
# Default value of `config_notebook`'s `config_path` argument, which is passed to Hydra's `initialize`.
CONFIG_DIR = "../lerobot/configs"
|
||
|
# Default value of `config_notebook`'s `config_name` argument, which is passed to Hydra's `compose`.
DEFAULT_CONFIG = "default"
|
||
|
|
||
|
|
||
|
def config_notebook(
    policy: str = "diffusion",
    env: str = "pusht",
    device: str = "cpu",
    config_name=DEFAULT_CONFIG,
    config_path=CONFIG_DIR,
    print_config: bool = False,
) -> DictConfig:
    """Compose a Hydra config with policy, env and device overrides.

    The global Hydra instance is cleared and Hydra is re-initialized with
    ``config_path`` on every call, then ``config_name`` is composed with the
    overrides ``env=<env>``, ``policy=<policy>`` and ``device=<device>``
    (in that order).

    Args:
        policy: Value used for the ``policy=`` override.
        env: Value used for the ``env=`` override.
        device: Value used for the ``device=`` override.
        config_name: Name of the config handed to ``compose``.
        config_path: Config location handed to ``initialize``.
        print_config: When true, pretty-print the composed config (converted
            to a plain container) before returning it.

    Returns:
        The config object produced by ``compose``.
    """
    # Drop any previously initialized global Hydra state before initializing
    # again (presumably so the function can be called more than once in the
    # same session -- confirm against Hydra's docs).
    GlobalHydra.instance().clear()
    initialize(config_path=config_path)

    # (group, choice) pairs, rendered as "group=choice" override strings.
    selections = (("env", env), ("policy", policy), ("device", device))
    composed = compose(
        config_name=config_name,
        overrides=[f"{group}={choice}" for group, choice in selections],
    )

    if print_config:
        plain = OmegaConf.to_container(composed)
        pprint(plain)

    return composed
|