This commit is contained in:
Simon Alibert 2024-03-22 00:44:02 +01:00
parent cfbbb4e80a
commit bc36fefa8e
4 changed files with 115 additions and 102 deletions

View File

@ -1,4 +1,5 @@
# ruff: noqa
from pathlib import Path
from pprint import pprint
from hydra import compose, initialize
@ -16,6 +17,7 @@ def config_notebook(
device: str = "cpu",
config_name=DEFAULT_CONFIG,
config_path=CONFIG_DIR,
pretrained_model_path: str = None,
print_config: bool = False,
) -> DictConfig:
GlobalHydra.instance().clear()
@ -24,6 +26,9 @@ def config_notebook(
f"env={env}",
f"policy={policy}",
f"device={device}",
f"policy.pretrained_model_path={pretrained_model_path}",
f"eval_episodes=1",
f"env.episode_length=200",
]
cfg = compose(config_name=config_name, overrides=overrides)
if print_config:

View File

@ -2,131 +2,77 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/tensordict/_pytree.py:147: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" register_pytree_node(\n",
"/Users/simon/projects/lerobot/examples/notebook_utils.py:21: UserWarning: \n",
"The version_base parameter is not specified.\n",
"Please specify a compatability version level, or None.\n",
"Will assume defaults for version 1.1\n",
" initialize(config_path=config_path)\n"
]
}
],
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"from huggingface_hub import snapshot_download\n",
"from hydra import compose, initialize\n",
"from hydra.core.global_hydra import GlobalHydra\n",
"from IPython.display import Video\n",
"from omegaconf import OmegaConf\n",
"from omegaconf.dictconfig import DictConfig\n",
"\n",
"from examples.notebook_utils import config_notebook\n",
"from examples.pretrained_script import download_eval_pretrained\n",
"from lerobot.scripts.eval import eval\n",
"\n",
"# Select policy and env\n",
"POLICY = \"act\" # \"tdmpc\" | \"diffusion\"\n",
"ENV = \"aloha\" # \"pusht\" | \"simxarm\"\n",
"POLICY = \"diffusion\" # \"tdmpc\" | \"diffusion\"\n",
"ENV = \"pusht\" # \"pusht\" | \"simxarm\"\n",
"\n",
"# Select device\n",
"DEVICE = \"cpu\" # \"cuda\" | \"mps\"\n",
"DEVICE = \"mps\" # \"cuda\" | \"mps\"\n",
"\n",
"# Generated videos will be written here\n",
"OUT_DIR = Path(\"./outputs\")\n",
"OUT_EXAMPLE = OUT_DIR / \"eval\" / \"eval_episode_0.mp4\"\n",
"\n",
"PRETRAINED_REPO = \"lerobot/diffusion_policy_pusht_image\"\n",
"pretrained_folder = Path(snapshot_download(repo_id=PRETRAINED_REPO, repo_type=\"model\", revision=\"v1.0\"))\n",
"pretrained_model_path = pretrained_folder / \"model.pt\"\n",
"\n",
"cfg_path = pretrained_folder / \"config.yaml\"\n",
"GlobalHydra.instance().clear()\n",
"\n",
"print(pretrained_folder)\n",
"\n",
"initialize(config_path=\"../../../.cache/huggingface/hub/models--lerobot--diffusion_policy_pusht_image/snapshots/163d168f5c193c356b82e3bf6bbf5b4eeaa780d7\")\n",
"overrides = [\n",
" f\"env={ENV}\",\n",
" f\"policy={POLICY}\",\n",
" f\"device={DEVICE}\",\n",
" f\"+policy.pretrained_model_path={pretrained_model_path}\",\n",
" f\"eval_episodes=1\",\n",
" f\"+env.episode_length=200\",\n",
"]\n",
"cfg = compose(config_name=\"config\", overrides=overrides)\n",
"pprint(OmegaConf.to_container(cfg))\n",
"# Setup config\n",
"cfg = config_notebook(policy=POLICY, env=ENV, device=DEVICE, print_config=False)"
"#cfg = config_notebook(cfg_path, policy=POLICY, env=ENV, device=DEVICE, print_config=False, pretrained_model_path=pretrained_model_path)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING 2024-03-21 23:15:20 mon/utils.py:20 Using CPU, this will be slow.\n",
"INFO 2024-03-21 23:15:20 on/logger.py:10 \u001b[1m\u001b[33mOutput dir:\u001b[0m outputs\n",
"INFO 2024-03-21 23:15:20 pts/eval.py:142 make_offline_buffer\n",
"INFO 2024-03-21 23:15:20 s/factory.py:44 use prioritized sampler for offline dataset\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "707a0e9b2ca8403b86a3841f770a7b71",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 31 files: 0%| | 0/31 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO 2024-03-21 23:15:20 pts/eval.py:145 make_env\n",
"INFO 2024-03-21 23:15:20 /__init__.py:88 MUJOCO_GL is not set, so an OpenGL backend will be chosen automatically.\n",
"INFO 2024-03-21 23:15:20 /__init__.py:96 Successfully imported OpenGL backend: %s\n",
"INFO 2024-03-21 23:15:20 /__init__.py:31 MuJoCo library version is: %s\n",
"WARNING 2024-03-21 23:15:21 loha/env.py:299 Aloha env is not seeded\n",
"/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n",
"INFO 2024-03-21 23:15:21 ct/policy.py:52 KL Weight 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of parameters: 83.92M\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1 [00:00<?, ?it/s]WARNING 2024-03-21 23:15:22 loha/env.py:299 Aloha env is not seeded\n",
"WARNING 2024-03-21 23:15:22 loha/env.py:299 Aloha env is not seeded\n",
" 0%| | 0/1 [00:00<?, ?it/s]\n"
]
},
{
"ename": "RuntimeError",
"evalue": "output with shape [14] doesn't match the broadcast shape [1, 14]",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43meval\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mOUT_DIR\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m Video(OUT_EXAMPLE, embed\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/projects/lerobot/lerobot/scripts/eval.py:167\u001b[0m, in \u001b[0;36meval\u001b[0;34m(cfg, out_dir)\u001b[0m\n\u001b[1;32m 150\u001b[0m policy \u001b[38;5;241m=\u001b[39m TensorDictModule(\n\u001b[1;32m 151\u001b[0m policy,\n\u001b[1;32m 152\u001b[0m in_keys\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobservation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstep_count\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 153\u001b[0m out_keys\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maction\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 154\u001b[0m )\n\u001b[1;32m 155\u001b[0m \u001b[38;5;66;03m# TODO(aliberts, Cadene): fetch pretrained model from HF hub\u001b[39;00m\n\u001b[1;32m 156\u001b[0m \u001b[38;5;66;03m# if cfg.policy.pretrained_model_path:\u001b[39;00m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;66;03m# policy = make_policy(cfg)\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;66;03m# # when policy is None, rollout a random policy\u001b[39;00m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;66;03m# policy = None\u001b[39;00m\n\u001b[0;32m--> 167\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[43meval_policy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[43m \u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[43m \u001b[49m\u001b[43mpolicy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpolicy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 170\u001b[0m \u001b[43m \u001b[49m\u001b[43msave_video\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 171\u001b[0m \u001b[43m \u001b[49m\u001b[43mvideo_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mPath\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout_dir\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43meval\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[43m \u001b[49m\u001b[43mfps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 173\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepisode_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_episodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meval_episodes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[38;5;28mprint\u001b[39m(metrics)\n\u001b[1;32m 178\u001b[0m logging\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEnd of eval\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/projects/lerobot/lerobot/scripts/eval.py:61\u001b[0m, in \u001b[0;36meval_policy\u001b[0;34m(env, policy, num_episodes, max_steps, save_video, video_dir, fps, return_first_video)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m policy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 60\u001b[0m policy\u001b[38;5;241m.\u001b[39mclear_action_queue()\n\u001b[0;32m---> 61\u001b[0m rollout \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrollout\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mpolicy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpolicy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mauto_cast_to_device\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmaybe_render_frame\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m \u001b[49m\u001b[43mbreak_when_any_done\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 67\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;66;03m# Figure out where in each rollout sequence the first done condition was encountered (results after this won't\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# be included).\u001b[39;00m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;66;03m# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;66;03m# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.\u001b[39;00m\n\u001b[1;32m 72\u001b[0m rollout_steps \u001b[38;5;241m=\u001b[39m rollout[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnext\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdone\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/torchrl/envs/common.py:2388\u001b[0m, in \u001b[0;36mEnvBase.rollout\u001b[0;34m(self, max_steps, policy, callback, auto_reset, auto_cast_to_device, break_when_any_done, return_contiguous, tensordict, out)\u001b[0m\n\u001b[1;32m 2384\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensordict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 2385\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 2386\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtensordict cannot be provided when auto_reset is True\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2387\u001b[0m )\n\u001b[0;32m-> 2388\u001b[0m tensordict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2389\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m tensordict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 2390\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtensordict must be provided when auto_reset is False\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/torchrl/envs/common.py:2068\u001b[0m, in \u001b[0;36mEnvBase.reset\u001b[0;34m(self, tensordict, **kwargs)\u001b[0m\n\u001b[1;32m 2065\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensordict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 2066\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_assert_tensordict_shape(tensordict)\n\u001b[0;32m-> 2068\u001b[0m tensordict_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_reset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensordict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2069\u001b[0m \u001b[38;5;66;03m# We assume that this is done properly\u001b[39;00m\n\u001b[1;32m 2070\u001b[0m \u001b[38;5;66;03m# if reset.device != self.device:\u001b[39;00m\n\u001b[1;32m 2071\u001b[0m \u001b[38;5;66;03m# reset = reset.to(self.device, non_blocking=True)\u001b[39;00m\n\u001b[1;32m 2072\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tensordict_reset \u001b[38;5;129;01mis\u001b[39;00m tensordict:\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/torchrl/envs/batched_envs.py:58\u001b[0m, in \u001b[0;36m_check_start.<locals>.decorated_fun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m, ParallelEnv):\n\u001b[1;32m 57\u001b[0m _check_for_faulty_process(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_workers)\n\u001b[0;32m---> 58\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/torchrl/envs/batched_envs.py:796\u001b[0m, in \u001b[0;36mSerialEnv._reset\u001b[0;34m(self, tensordict, **kwargs)\u001b[0m\n\u001b[1;32m 793\u001b[0m tensordict_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 795\u001b[0m _td \u001b[38;5;241m=\u001b[39m _env\u001b[38;5;241m.\u001b[39mreset(tensordict\u001b[38;5;241m=\u001b[39mtensordict_, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 796\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared_tensordicts\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate_\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 797\u001b[0m \u001b[43m \u001b[49m\u001b[43m_td\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 798\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeys_to_update\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_selected_reset_keys_filt\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 799\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 800\u001b[0m selected_output_keys \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_selected_reset_keys_filt\n\u001b[1;32m 801\u001b[0m device \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/tensordict/base.py:2717\u001b[0m, in \u001b[0;36mTensorDictBase.update_\u001b[0;34m(self, input_dict_or_td, clone, keys_to_update)\u001b[0m\n\u001b[1;32m 2712\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtensordict\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TensorDict\n\u001b[1;32m 2714\u001b[0m input_dict_or_td \u001b[38;5;241m=\u001b[39m TensorDict\u001b[38;5;241m.\u001b[39mfrom_dict(\n\u001b[1;32m 2715\u001b[0m input_dict_or_td, batch_dims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_dims\n\u001b[1;32m 2716\u001b[0m )\n\u001b[0;32m-> 2717\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply_nest\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2718\u001b[0m \u001b[43m \u001b[49m\u001b[43minplace_update\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2719\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_dict_or_td\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2720\u001b[0m \u001b[43m \u001b[49m\u001b[43mnested_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2721\u001b[0m \u001b[43m \u001b[49m\u001b[43mdefault\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2722\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilter_empty\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2723\u001b[0m \u001b[43m \u001b[49m\u001b[43mnamed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnamed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2724\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_leaf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_is_leaf_nontensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2725\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2726\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/tensordict/_td.py:720\u001b[0m, in \u001b[0;36mTensorDict._apply_nest\u001b[0;34m(self, fn, batch_size, device, names, inplace, checked, call_on_nested, default, named, nested_keys, prefix, filter_empty, is_leaf, *others, **constructor_kwargs)\u001b[0m\n\u001b[1;32m 715\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 716\u001b[0m _others \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 717\u001b[0m _other\u001b[38;5;241m.\u001b[39m_get_str(key, default\u001b[38;5;241m=\u001b[39mNO_DEFAULT) \u001b[38;5;28;01mfor\u001b[39;00m _other \u001b[38;5;129;01min\u001b[39;00m others\n\u001b[1;32m 718\u001b[0m ]\n\u001b[0;32m--> 720\u001b[0m item_trsf \u001b[38;5;241m=\u001b[39m \u001b[43mitem\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply_nest\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 721\u001b[0m \u001b[43m \u001b[49m\u001b[43mfn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 722\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m_others\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 723\u001b[0m \u001b[43m \u001b[49m\u001b[43minplace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minplace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 724\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 725\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 726\u001b[0m \u001b[43m \u001b[49m\u001b[43mchecked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchecked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 727\u001b[0m \u001b[43m \u001b[49m\u001b[43mnamed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnamed\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 728\u001b[0m \u001b[43m \u001b[49m\u001b[43mnested_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnested_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 729\u001b[0m \u001b[43m \u001b[49m\u001b[43mdefault\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdefault\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 730\u001b[0m \u001b[43m \u001b[49m\u001b[43mprefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprefix\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 731\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilter_empty\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfilter_empty\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 732\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_leaf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_leaf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 733\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconstructor_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 734\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 736\u001b[0m _others \u001b[38;5;241m=\u001b[39m [_other\u001b[38;5;241m.\u001b[39m_get_str(key, default\u001b[38;5;241m=\u001b[39mdefault) \u001b[38;5;28;01mfor\u001b[39;00m _other \u001b[38;5;129;01min\u001b[39;00m others]\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/tensordict/_td.py:739\u001b[0m, in \u001b[0;36mTensorDict._apply_nest\u001b[0;34m(self, fn, batch_size, device, names, inplace, checked, call_on_nested, default, named, nested_keys, prefix, filter_empty, is_leaf, *others, **constructor_kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m named:\n\u001b[1;32m 738\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m nested_keys:\n\u001b[0;32m--> 739\u001b[0m item_trsf \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43munravel_key\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprefix\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mitem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m_others\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 740\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 741\u001b[0m item_trsf \u001b[38;5;241m=\u001b[39m fn(key, item, \u001b[38;5;241m*\u001b[39m_others)\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/lerobot/lib/python3.10/site-packages/tensordict/base.py:2701\u001b[0m, in \u001b[0;36mTensorDictBase.update_.<locals>.inplace_update\u001b[0;34m(name, dest, source)\u001b[0m\n\u001b[1;32m 2699\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m keys_to_update:\n\u001b[1;32m 2700\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;241m==\u001b[39m name[: \u001b[38;5;28mlen\u001b[39m(key)]:\n\u001b[0;32m-> 2701\u001b[0m \u001b[43mdest\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy_\u001b[49m\u001b[43m(\u001b[49m\u001b[43msource\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: output with shape [14] doesn't match the broadcast shape [1, 14]"
]
}
],
"outputs": [],
"source": [
"eval(cfg, out_dir=OUT_DIR)\n",
"# eval(cfg, out_dir=OUT_DIR)\n",
"download_eval_pretrained(OUT_DIR, cfg)\n",
"Video(OUT_EXAMPLE, embed=True)"
]
}

View File

@ -0,0 +1,62 @@
import logging
from pathlib import Path
import torch
from tensordict.nn import TensorDictModule
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
from lerobot.scripts.eval import eval_policy
def download_eval_pretrained(out_dir, cfg):
if out_dir is None:
raise NotImplementedError()
init_logging()
# Check device is available
get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
log_output_dir(out_dir)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)
policy = TensorDictModule(
policy,
in_keys=["observation", "step_count"],
out_keys=["action"],
)
else:
# when policy is None, rollout a random policy
policy = None
metrics = eval_policy(
env,
policy=policy,
save_video=True,
video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps,
max_steps=cfg.env.episode_length,
num_episodes=cfg.eval_episodes,
)
print(metrics)
logging.info("End of eval")
if __name__ == "__main__":
download_eval_pretrained()

View File

@ -202,7 +202,7 @@ class DiffusionPolicy(AbstractPolicy):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
d = torch.load(fp, map_location=torch.device(self.device))
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)