From be6364f1093be9fbbbcd358222ea5c03c97246d9 Mon Sep 17 00:00:00 2001
From: Cadene <re.cadene@gmail.com>
Date: Sun, 24 Mar 2024 23:10:16 +0000
Subject: [PATCH] Fix online rollout insertion into the buffer; training runs now

---
 lerobot/scripts/train.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index cf71ad2e..91d2cf00 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -224,7 +224,22 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 policy=td_policy,
                 auto_cast_to_device=True,
             )
-        assert len(rollout) <= cfg.env.episode_length
+
+        assert (
+            len(rollout.batch_size) == 2
+        ), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
+
+        num_parallel_env = rollout.batch_size[0]
+        if num_parallel_env != 1:
+            # TODO(rcadene): when num_parallel_env > 1, the episode index needs to be incremented and tests need to be added
+            raise NotImplementedError("num_parallel_env > 1 is not supported yet")
+
+        num_max_steps = rollout.batch_size[1]
+        assert num_max_steps <= cfg.env.episode_length
+
+        # flatten the rollout into a 1D list of steps before inserting them into online_buffer
+        rollout = rollout.reshape(num_parallel_env * num_max_steps)
+
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)
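
Note: the reshape-then-tag flow added in this hunk can be exercised in isolation with a toy rollout. The sketch below is a stand-in under stated assumptions, not the repo's code: "observation"/"action" are placeholder keys, the episode index 0 stands in for env_step, and a TensorDictReplayBuffer stands in for the online_buffer constructed elsewhere in train.py.

    # Standalone sketch of the reshape-then-tag logic (assumes tensordict and torchrl are installed).
    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

    num_parallel_env, num_max_steps = 1, 5  # single-env case, as handled by the patch

    # Fake rollout with the expected 2D batch size: (num parallel envs, max rollout steps).
    rollout = TensorDict(
        {
            "observation": torch.randn(num_parallel_env, num_max_steps, 4),
            "action": torch.randn(num_parallel_env, num_max_steps, 2),
        },
        batch_size=[num_parallel_env, num_max_steps],
    )
    assert len(rollout.batch_size) == 2

    # Flatten to a list of steps, then tag every step with the same episode index.
    rollout = rollout.reshape(num_parallel_env * num_max_steps)
    rollout["episode"] = torch.tensor([0] * len(rollout), dtype=torch.int)  # 0 stands in for env_step

    online_buffer = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=100))
    online_buffer.extend(rollout)
    print(len(online_buffer))  # 5

The episode tag has to be written after the reshape so that its shape matches the flattened batch size, which keeps the buffer at one step per row.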