diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index e9d57cba..839c12bb 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -33,6 +33,7 @@ def eval_policy( fps: int = 15, return_first_video: bool = False, ): + policy.eval() start = time.time() sum_rewards = [] max_rewards = [] diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 579f5a58..2c7bb575 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -190,6 +190,7 @@ def train(cfg: dict, out_dir=None, job_name=None): if offline_step == 0: logging.info("Start offline training on a fixed dataset") # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? + policy.train() train_info = policy.update(offline_buffer, step) if step % cfg.log_freq == 0: log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)