# lerobot/examples/notebook_utils.py
# (file-viewer header: 66 lines, 1.6 KiB, Python)

# ruff: noqa
from pathlib import Path
from pprint import pprint
from typing import Optional

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
# Default `config_path` for config_notebook. Hydra resolves a relative
# config_path against the directory of the calling file (this file lives in
# examples/), so this reaches the package's configs directory.
CONFIG_DIR: str = "../lerobot/configs"
# Default primary config name composed by config_notebook.
DEFAULT_CONFIG: str = "default"
def config_notebook(
    policy: str = "diffusion",
    env: str = "pusht",
    device: str = "cpu",
    config_name=DEFAULT_CONFIG,
    config_path=CONFIG_DIR,
    pretrained_model_path: Optional[str] = None,
    print_config: bool = False,
    eval_episodes: int = 1,
    episode_length: int = 200,
) -> DictConfig:
    """Compose a Hydra config for interactive (notebook) use.

    Any previously initialized global Hydra instance is cleared first, so the
    function can be called repeatedly in the same process. The config named
    ``config_name`` is then composed from ``config_path`` with overrides for
    the env, policy, device, pretrained model path and evaluation settings.

    Args:
        policy: Value of the ``policy=`` override (e.g. "diffusion", "act").
        env: Value of the ``env=`` override (e.g. "pusht", "aloha").
        device: Value of the ``device=`` override (e.g. "cpu", "cuda", "mps").
        config_name: Name of the primary config to compose.
        config_path: Config directory handed to Hydra's ``initialize``; a
            relative path is resolved by Hydra against the caller's directory.
        pretrained_model_path: Value for ``policy.pretrained_model_path``.
            ``None`` is sent as Hydra's ``null`` so the key ends up null
            rather than holding the string "None".
        print_config: If True, pretty-print the composed config as a plain
            container.
        eval_episodes: Value for ``eval_episodes`` (default 1).
        episode_length: Value for ``env.episode_length`` (default 200).

    Returns:
        The composed ``DictConfig``.
    """
    # Hydra keeps a process-wide singleton; clear it so re-running a notebook
    # cell does not fail with "GlobalHydra is already initialized".
    GlobalHydra.instance().clear()
    initialize(config_path=config_path)

    # Hydra's override grammar spells null as `null`; formatting Python's None
    # directly would produce the literal string "None".
    model_path = "null" if pretrained_model_path is None else pretrained_model_path

    overrides = [
        f"env={env}",
        f"policy={policy}",
        f"device={device}",
        f"policy.pretrained_model_path={model_path}",
        f"eval_episodes={eval_episodes}",
        f"env.episode_length={episode_length}",
    ]
    cfg = compose(config_name=config_name, overrides=overrides)
    if print_config:
        pprint(OmegaConf.to_container(cfg))
    return cfg
def notebook(
    policy: str = "act",
    env: str = "aloha",
    device: str = "mps",
    out_dir="./outputs",
):
    """Compose a config and run one evaluation, mirroring the example notebook.

    Called with no arguments, this uses the same settings as the original
    script: the "act" policy on the "aloha" env, on the "mps" device, writing
    to ``./outputs``.

    Args:
        policy: Policy config to select (e.g. "act", "diffusion", "tdmpc").
        env: Environment config to select (e.g. "aloha", "pusht", "simxarm").
        device: Device string forwarded to the config (e.g. "cpu", "cuda", "mps").
        out_dir: Directory passed to the evaluation as ``out_dir``. Generated
            videos are presumably written below it; this is not verified here.
    """
    # Kept local so that importing this module does not pull in the evaluation
    # script. Aliased so the builtin `eval` is not shadowed.
    from lerobot.scripts.eval import eval as eval_policy

    # Use this module's own config_notebook. Importing it via the `examples`
    # package would load this file a second time, and breaks when `examples`
    # is not on sys.path.
    cfg = config_notebook(policy=policy, env=env, device=device, print_config=False)
    eval_policy(cfg, out_dir=Path(out_dir))
# When executed as a script (not imported), run the example end to end with
# notebook()'s default settings.
if __name__ == "__main__":
    notebook()