remove policy is None eval end-to-end tests

This commit is contained in:
Cadene 2024-04-10 15:09:04 +00:00
parent 2186429fa8
commit 8866b22db1
2 changed files with 4 additions and 36 deletions

View File

@ -160,17 +160,6 @@ jobs:
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
# TODO(aliberts): This takes ~2mn to run, needs to be improved
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
# run: |
# source .venv/bin/activate
# python lerobot/scripts/eval.py \
# --config lerobot/configs/default.yaml \
# policy=act \
# env=aloha \
# eval_episodes=1 \
# device=cpu
- name: Test train Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
@ -197,17 +186,6 @@ jobs:
device=cpu \
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
- name: Test eval Diffusion on PushT end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=diffusion \
env=pusht \
eval_episodes=1 \
env.episode_length=8 \
device=cpu
- name: Test train TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
@ -233,13 +211,3 @@ jobs:
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=tdmpc \
env=xarm \
eval_episodes=1 \
device=cpu

View File

@ -57,7 +57,7 @@ def write_video(video_path, stacked_frames, fps):
def eval_policy(
env: gym.vector.VectorEnv,
policy,
policy: torch.nn.Module,
max_episodes_rendered: int = 0,
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
@ -312,12 +312,12 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
# when policy is None, rollout a random policy
policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
logging.info("Making policy.")
policy = make_policy(cfg)
info = eval_policy(
env,
policy=policy,
policy,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
# TODO(rcadene): what should we do with the transform?