32 lines
785 B
Python
32 lines
785 B
Python
from pprint import pprint
from typing import Optional, Sequence

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
|
|
|
|
# Directory holding the Hydra configs; default `config_path` for
# config_notebook().
CONFIG_DIR: str = "../lerobot/configs"
# Primary config name; default `config_name` passed to Hydra's compose().
DEFAULT_CONFIG: str = "default"
|
|
|
|
|
|
def config_notebook(
    policy: str = "diffusion",
    env: str = "pusht",
    device: str = "cpu",
    config_name: str = DEFAULT_CONFIG,
    config_path: str = CONFIG_DIR,
    print_config: bool = False,
    extra_overrides: Optional[Sequence[str]] = None,
) -> DictConfig:
    """Compose a Hydra config for interactive (notebook) use.

    Builds the config ``config_name`` found under ``config_path`` with the
    ``env``, ``policy`` and ``device`` overrides applied, followed by any
    ``extra_overrides``.

    Args:
        policy: Value for the ``policy=`` override.
        env: Value for the ``env=`` override.
        device: Value for the ``device=`` override.
        config_name: Primary config name passed to ``compose``.
        config_path: Config directory passed to ``initialize``.
            NOTE(review): Hydra resolves a relative path against the module
            that calls ``initialize`` (this file), not the current working
            directory -- confirm for the installed Hydra version.
        print_config: If true, pretty-print the composed config converted to
            a plain container (``OmegaConf.to_container`` with its default
            arguments, i.e. interpolations are not resolved).
        extra_overrides: Additional Hydra override strings such as
            ``"key=value"``. They are appended after the ``env``/``policy``/
            ``device`` overrides. A single string is treated as one override.

    Returns:
        The composed ``DictConfig``.
    """
    # initialize() below is not used as a context manager, so Hydra remains
    # initialized after this function returns; clear any such leftover state
    # so repeated calls (e.g. re-running a notebook cell) can initialize again.
    GlobalHydra.instance().clear()
    # NOTE(review): newer Hydra releases warn when `version_base` is omitted;
    # it is left unset here to keep the currently active defaults.
    initialize(config_path=config_path)

    overrides = [
        f"env={env}",
        f"policy={policy}",
        f"device={device}",
    ]
    if extra_overrides:
        # Guard against a bare string, which list.extend would split into
        # single characters.
        if isinstance(extra_overrides, str):
            overrides.append(extra_overrides)
        else:
            overrides.extend(extra_overrides)

    cfg = compose(config_name=config_name, overrides=overrides)

    if print_config:
        pprint(OmegaConf.to_container(cfg))

    return cfg
|