This commit is contained in:
Cadene 2024-03-22 00:13:39 +00:00
parent 70d7b99d09
commit b1e4b7967e
1 changed files with 7 additions and 11 deletions

View File

@ -1,14 +1,9 @@
import logging
from omegaconf import OmegaConf
from pathlib import Path
from lerobot.scripts.eval import eval_policy
from huggingface_hub import snapshot_download
import logging import logging
from pathlib import Path from pathlib import Path
import torch import torch
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.factory import make_offline_buffer
@ -16,15 +11,16 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
from lerobot.scripts.eval import eval_policy
folder = Path(snapshot_download('lerobot/diffusion_policy_pusht_image', revision="v1.0")) folder = Path(snapshot_download("lerobot/diffusion_policy_pusht_image", revision="v1.0"))
cfg = OmegaConf.load(folder / "config.yaml") cfg = OmegaConf.load(folder / "config.yaml")
cfg.policy.pretrained_model_path = folder / "model.pt" cfg.policy.pretrained_model_path = folder / "model.pt"
cfg.eval_episodes = 1 cfg.eval_episodes = 1
cfg.episode_length = 50 cfg.episode_length = 50
cfg.device = "cpu" # cfg.device = "cpu"
out_dir = "test" out_dir = "tmp/"
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()