From 9241b5e8302bc2c9fe415b1c1c2f988ead6de746 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 15 Apr 2024 09:52:54 +0100
Subject: [PATCH] pass step as kwarg

---
 lerobot/scripts/eval.py  | 2 +-
 lerobot/scripts/train.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index d676623e..2b8906d7 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -130,7 +130,7 @@ def eval_policy(
 
         # get the next action for the environment
         with torch.inference_mode():
-            action = policy.select_action(observation, step)
+            action = policy.select_action(observation, step=step)
 
         # apply inverse transform to unnormalize the action
         action = postprocess_action(action, transform)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 03506f2a..5ff6538d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy(batch, step)
+        train_info = policy(batch, step=step)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
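
Note: for reference, a minimal sketch of the policy interface these call sites assume. This is an assumption based only on the calls changed above, not the actual lerobot Policy definition; the keyword-only marker (*) simply illustrates why passing `step=step` rather than a positional argument is the safer convention at the call sites.

    # Hypothetical interface implied by the call sites above (not the real lerobot code).
    # Making `step` keyword-only forces callers to write `step=step`, matching this patch.
    import torch
    from torch import Tensor

    class Policy(torch.nn.Module):
        def select_action(self, observation: dict[str, Tensor], *, step: int) -> Tensor:
            # Choose the next environment action; `step` can drive schedules such as exploration noise.
            ...

        def forward(self, batch: dict[str, Tensor], *, step: int) -> dict:
            # One training update on `batch`; returns a dict of logging info (`train_info`).
            ...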