diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index cf71ad2e..91d2cf00 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -224,7 +224,22 @@ def train(cfg: dict, out_dir=None, job_name=None): policy=td_policy, auto_cast_to_device=True, ) - assert len(rollout) <= cfg.env.episode_length + + assert ( + len(rollout.batch_size) == 2 + ), "2 dimensions expected: number of env in parallel x max number of steps during rollout" + + num_parallel_env = rollout.batch_size[0] + if num_parallel_env != 1: + # TODO(rcadene): when num_parallel_env > 1, episode needs to be incremented and we need to add tests + raise NotImplementedError() + + num_max_steps = rollout.batch_size[1] + assert num_max_steps <= cfg.env.episode_length + + # reshape to have a list of steps to insert into online_buffer + rollout = rollout.reshape(num_parallel_env * num_max_steps) + # set same episode index for all time steps contained in this rollout rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) online_buffer.extend(rollout)