From 2ccf89d78c32a9beaed716be6310a8239d78e6de Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 09:47:25 +0100 Subject: [PATCH 1/2] try fix tests --- .github/workflows/test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3411e11..a86193b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -146,7 +146,8 @@ jobs: device=cpu \ save_model=true \ save_freq=2 \ - horizon=20 \ + policy.n_action_steps=20 \ + policy.chunk_size=20 \ policy.batch_size=2 \ hydra.run.dir=tests/outputs/act/ From 9241b5e8302bc2c9fe415b1c1c2f988ead6de746 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 09:52:54 +0100 Subject: [PATCH 2/2] pass step as kwarg --- lerobot/scripts/eval.py | 2 +- lerobot/scripts/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index d676623e..2b8906d7 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -130,7 +130,7 @@ def eval_policy( # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step) + action = policy.select_action(observation, step=step) # apply inverse transform to unnormalize the action action = postprocess_action(action, transform) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 03506f2a..5ff6538d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step) + train_info = policy(batch, step=step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: