Add option for random policy
parent 5a5b190f70
commit 937b2f8cba
@@ -4,9 +4,9 @@ import hydra
 import imageio
 import numpy as np
 import torch
-from tensordict import TensorDict
 from tensordict.nn import TensorDictModule
 from termcolor import colored
+from torchrl.envs import EnvBase
 
 from lerobot.common.envs.factory import make_env
 from lerobot.common.tdmpc import TDMPC
@@ -14,7 +14,12 @@ from lerobot.common.utils import set_seed
 
 
 def eval_policy(
-    env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None
+    env: EnvBase,
+    policy: TensorDictModule = None,
+    num_episodes: int = 10,
+    max_steps: int = 30,
+    save_video: bool = False,
+    video_dir: Path = None,
 ):
     rewards = []
     successes = []
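
Note on the signature change: `policy` is now typed as an optional `TensorDictModule` defaulting to `None`, which is what makes the random option possible. A minimal sketch of what a caller could pass (the module, key names, and sizes below are illustrative assumptions, not code from this repo):

# Sketch only, not from this commit: what the new `policy` parameter expects.
import torch
from tensordict.nn import TensorDictModule

# Any torch module can be driven by torchrl's rollout loop once it declares
# which tensordict keys it reads and writes.
net = torch.nn.Linear(4, 2)  # illustrative sizes
policy = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])

# Passing policy=None instead (the new default) leaves action selection to the
# environment's action spec, i.e. a random policy.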
@@ -31,7 +36,7 @@ def eval_policy(
         rendering_callback(env)
 
         rollout = env.rollout(
-            max_steps=30,
+            max_steps=max_steps,
             policy=policy,
             callback=rendering_callback,
             auto_reset=False,
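
Why forwarding `policy=None` here is enough: torchrl's `EnvBase.rollout()` samples each action from the environment's action spec when it is given no policy, so `eval_policy` needs no extra branching. A standalone sketch of that behavior (the `GymEnv` wrapper and env name are assumptions for illustration; this script builds its env with `make_env` instead):

# Sketch: rollout() with no policy draws random actions from the action spec.
from torchrl.envs.libs.gym import GymEnv  # assumes gym is installed

env = GymEnv("Pendulum-v1")
rollout = env.rollout(max_steps=30, policy=None)  # random action at every step
print(rollout["action"].shape)  # one randomly sampled action per step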
@@ -73,9 +78,10 @@ def eval(cfg: dict):
         out_keys=["action"],
     )
 
+    # policy can be None to rollout a random policy
     metrics = eval_policy(
         env,
-        policy,
+        policy=policy,
         num_episodes=10,
         save_video=True,
         video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
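
Following the comment added above, the random baseline would be requested by passing `None` at this call site; a hedged variant of the call (the video directory is a placeholder, not a path from the repo):

# Hypothetical call: evaluate a random policy instead of the trained one.
metrics = eval_policy(
    env,
    policy=None,  # no TensorDictModule -> env.rollout() samples random actions
    num_episodes=10,
    max_steps=30,
    save_video=True,
    video_dir=Path("tmp/random_policy_eval"),  # placeholder path
)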