remove policy is None eval end-to-end tests
parent 2186429fa8
commit 8866b22db1
@@ -160,17 +160,6 @@ jobs:
           device=cpu \
           policy.pretrained_model_path=tests/outputs/act/models/2.pt
 
-      # TODO(aliberts): This takes ~2mn to run, needs to be improved
-      # - name: Test eval ACT on ALOHA end-to-end (policy is None)
-      #   run: |
-      #     source .venv/bin/activate
-      #     python lerobot/scripts/eval.py \
-      #     --config lerobot/configs/default.yaml \
-      #     policy=act \
-      #     env=aloha \
-      #     eval_episodes=1 \
-      #     device=cpu
-
       - name: Test train Diffusion on PushT end-to-end
         run: |
           source .venv/bin/activate
@@ -197,17 +186,6 @@ jobs:
           device=cpu \
           policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
 
-      - name: Test eval Diffusion on PushT end-to-end (policy is None)
-        run: |
-          source .venv/bin/activate
-          python lerobot/scripts/eval.py \
-          --config lerobot/configs/default.yaml \
-          policy=diffusion \
-          env=pusht \
-          eval_episodes=1 \
-          env.episode_length=8 \
-          device=cpu
-
       - name: Test train TDMPC on Simxarm end-to-end
         run: |
           source .venv/bin/activate
@@ -233,13 +211,3 @@ jobs:
           env.episode_length=8 \
           device=cpu \
           policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
-
-      - name: Test eval TDPMC on Simxarm end-to-end (policy is None)
-        run: |
-          source .venv/bin/activate
-          python lerobot/scripts/eval.py \
-          --config lerobot/configs/default.yaml \
-          policy=tdmpc \
-          env=xarm \
-          eval_episodes=1 \
-          device=cpu
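As a side note on how these CI steps drive the script: the retained eval steps pass Hydra-style key=value overrides plus a pretrained checkpoint to lerobot/scripts/eval.py. Below is a minimal sketch of doing the same from Python, assuming the config is Hydra-based; the import path lerobot.scripts.eval, the relative config_path, and the out_dir value are assumptions for illustration, not taken from this diff.

# Hypothetical sketch: mirror the retained "eval Diffusion on PushT" CI step from Python.
# Assumptions: lerobot's config is Hydra-based and eval.py exposes eval(cfg, out_dir=...).
from hydra import compose, initialize

from lerobot.scripts.eval import eval as run_eval  # assumed import path

with initialize(config_path="../lerobot/configs"):  # relative to this file; adjust as needed
    cfg = compose(
        config_name="default",
        overrides=[
            "policy=diffusion",
            "env=pusht",
            "eval_episodes=1",
            "env.episode_length=8",
            "device=cpu",
            "policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt",
        ],
    )

run_eval(cfg, out_dir="outputs/eval/diffusion_pusht")  # out_dir chosen for illustration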
lerobot/scripts/eval.py

@@ -57,7 +57,7 @@ def write_video(video_path, stacked_frames, fps):
 
 def eval_policy(
     env: gym.vector.VectorEnv,
-    policy,
+    policy: torch.nn.Module,
     max_episodes_rendered: int = 0,
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
@@ -312,12 +312,12 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
-    # when policy is None, rollout a random policy
-    policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
+    logging.info("Making policy.")
+    policy = make_policy(cfg)
 
     info = eval_policy(
         env,
-        policy=policy,
+        policy,
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         # TODO(rcadene): what should we do with the transform?
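Taken together, the two hunks mean eval() now always constructs a policy and eval_policy() expects a torch.nn.Module rather than an optional one; the random-policy fallback noted in the removed comment is gone. Purely to illustrate the dropped behavior, here is a self-contained sketch (not lerobot code) of an eval loop with such a policy-is-None fallback, using gymnasium's CartPole:

# Illustrative only: an eval loop that falls back to random actions when policy is None,
# i.e. the case this commit drops support for upstream. Requires gymnasium and torch.
from typing import Optional

import gymnasium as gym
import torch

def rollout(env: gym.Env, policy: Optional[torch.nn.Module], max_steps: int = 100) -> float:
    """Roll out one episode; sample random actions when no policy is given."""
    obs, _ = env.reset(seed=0)
    total_reward = 0.0
    for _ in range(max_steps):
        if policy is None:
            action = env.action_space.sample()  # random-policy fallback (removed upstream)
        else:
            with torch.no_grad():
                action = int(policy(torch.as_tensor(obs, dtype=torch.float32)).argmax())
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += float(reward)
        if terminated or truncated:
            break
    return total_reward

print(rollout(gym.make("CartPole-v1"), policy=None))  # random rollout, the removed use case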