diff --git a/README.md b/README.md
index 4bba66df..07da0d9a 100644
--- a/README.md
+++ b/README.md
@@ -15,12 +15,20 @@ conda activate lerobot
 python setup.py develop
 ```
 
+## TODO
+
+- [ ] priority update doesnt match FOWM or original paper
+- [ ] self.step=100000 should be updated at every step to adjust to horizon of planner
+- [ ] prefetch replay buffer to speedup training
+- [ ] parallelize env to speedup eval
 
 ## Contribute
 
 **style**
 ```
-isort .
-black .
+isort lerobot
+black lerobot
+isort test
+black test
 pylint lerobot
 ```
diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py
index 1470fceb..8d955072 100644
--- a/lerobot/common/envs/simxarm.py
+++ b/lerobot/common/envs/simxarm.py
@@ -77,18 +77,16 @@ class SimxarmEnv(EnvBase):
 
     def _format_raw_obs(self, raw_obs):
         if self.from_pixels:
-            camera = self.render(
+            image = self.render(
                 mode="rgb_array", width=self.image_size, height=self.image_size
             )
-            camera = camera.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
-            camera = torch.tensor(camera.copy(), dtype=torch.uint8)
+            image = image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
+            image = torch.tensor(image.copy(), dtype=torch.uint8)
 
-            obs = {"camera": camera}
+            obs = {"image": image}
 
             if not self.pixels_only:
-                obs["robot_state"] = torch.tensor(
-                    self._env.robot_state, dtype=torch.float32
-                )
+                obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
         else:
             obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
 
@@ -136,7 +134,7 @@ class SimxarmEnv(EnvBase):
     def _make_spec(self):
         obs = {}
         if self.from_pixels:
-            obs["camera"] = BoundedTensorSpec(
+            obs["image"] = BoundedTensorSpec(
                 low=0,
                 high=255,
                 shape=(3, self.image_size, self.image_size),
@@ -144,7 +142,7 @@ class SimxarmEnv(EnvBase):
                 device=self.device,
             )
             if not self.pixels_only:
-                obs["robot_state"] = UnboundedContinuousTensorSpec(
+                obs["state"] = UnboundedContinuousTensorSpec(
                     shape=(len(self._env.robot_state),),
                     dtype=torch.float32,
                     device=self.device,
diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py
index da8638dd..d694a06d 100644
--- a/lerobot/common/tdmpc.py
+++ b/lerobot/common/tdmpc.py
@@ -96,8 +96,7 @@ class TDMPC(nn.Module):
         self.model_target.eval()
         self.batch_size = cfg.batch_size
 
-        # TODO(rcadene): clean
-        self.step = 100000
+        self.step = 0
 
     def state_dict(self):
         """Retrieve state dict of TOLD model, including slow-moving target network."""
@@ -120,8 +119,8 @@ class TDMPC(nn.Module):
     def forward(self, observation, step_count):
         t0 = step_count.item() == 0
         obs = {
-            "rgb": observation["camera"],
-            "state": observation["robot_state"],
+            "rgb": observation["image"],
+            "state": observation["state"],
         }
         return self.act(obs, t0=t0, step=self.step)
 
@@ -298,65 +297,81 @@ class TDMPC(nn.Module):
     def update(self, replay_buffer, step, demo_buffer=None):
         """Main update function. Corresponds to one iteration of the model learning."""
 
-        if demo_buffer is not None:
-            # Update oversampling ratio
-            self.demo_batch_size = int(
-                h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
-            )
-            replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
-            demo_buffer.cfg.batch_size = self.demo_batch_size
+        num_slices = self.cfg.batch_size
+        batch_size = self.cfg.horizon * num_slices
+
+        if demo_buffer is None:
+            demo_batch_size = 0
         else:
-            self.demo_batch_size = 0
+            # Update oversampling ratio
+            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
+            demo_num_slices = int(demo_pc_batch * self.batch_size)
+            demo_batch_size = self.cfg.horizon * demo_num_slices
+            batch_size -= demo_batch_size
+            num_slices -= demo_num_slices
+            replay_buffer._sampler.num_slices = num_slices
+            demo_buffer._sampler.num_slices = demo_num_slices
+
+            assert demo_batch_size % self.cfg.horizon == 0
+            assert demo_batch_size % demo_num_slices == 0
+
+        assert batch_size % self.cfg.horizon == 0
+        assert batch_size % num_slices == 0
 
         # Sample from interaction dataset
 
-        # to not have to mask
-        # batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
-        batch_size = self.cfg.horizon * self.cfg.batch_size
+        def process_batch(batch, horizon, num_slices):
+            # trajectory t = 256, horizon h = 5
+            # (t h) ... -> h t ...
+            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+            batch = batch.to("cuda")
+
+            FIRST_FRAME = 0
+            obs = {
+                "rgb": batch["observation", "image"][FIRST_FRAME].float(),
+                "state": batch["observation", "state"][FIRST_FRAME],
+            }
+            action = batch["action"]
+            next_obses = {
+                "rgb": batch["next", "observation", "image"].float(),
+                "state": batch["next", "observation", "state"],
+            }
+            reward = batch["next", "reward"]
+
+            # TODO(rcadene): rearrange directly in offline dataset
+            if reward.ndim == 2:
+                reward = einops.rearrange(reward, "h t -> h t 1")
+
+            assert reward.ndim == 3
+            assert reward.shape == (horizon, num_slices, 1)
+            # We dont use `batch["next", "done"]` since it only indicates the end of an
+            # episode, but not the end of the trajectory of an episode.
+            # Neither does `batch["next", "terminated"]`
+            done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+            mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+
+            idxs = batch["index"][FIRST_FRAME]
+            weights = batch["_weight"][FIRST_FRAME, :, None]
+            return obs, action, next_obses, reward, mask, done, idxs, weights
+
         batch = replay_buffer.sample(batch_size)
-
-        # trajectory t = 256, horizon h = 5
-        # (t h) ... -> h t ...
-        batch = (
-            batch.reshape(self.cfg.batch_size, self.cfg.horizon)
-            .transpose(1, 0)
-            .contiguous()
+        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
+            batch, self.cfg.horizon, num_slices
         )
-        batch = batch.to("cuda")
-
-        FIRST_FRAME = 0
-        obs = {
-            "rgb": batch["observation", "image"][FIRST_FRAME].float(),
-            "state": batch["observation", "state"][FIRST_FRAME],
-        }
-        action = batch["action"]
-        next_obses = {
-            "rgb": batch["next", "observation", "image"].float(),
-            "state": batch["next", "observation", "state"],
-        }
-        reward = batch["next", "reward"]
-        reward = einops.rearrange(reward, "h t -> h t 1")
-        # We dont use `batch["next", "done"]` since it only indicates the end of an
-        # episode, but not the end of the trajectory of an episode.
-        # Neither does `batch["next", "terminated"]`
-        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-
-        idxs = batch["frame_id"][FIRST_FRAME]
-        weights = batch["_weight"][FIRST_FRAME, :, None]
 
         # Sample from demonstration dataset
-        if self.demo_batch_size > 0:
+        if demo_batch_size > 0:
+            demo_batch = demo_buffer.sample(demo_batch_size)
             (
                 demo_obs,
-                demo_next_obses,
                 demo_action,
+                demo_next_obses,
                 demo_reward,
                 demo_mask,
                 demo_done,
                 demo_idxs,
                 demo_weights,
-            ) = demo_buffer.sample()
+            ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
 
             if isinstance(obs, dict):
                 obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
@@ -440,9 +455,9 @@ class TDMPC(nn.Module):
             q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
             priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
 
-        self.expectile = h.linear_schedule(self.cfg.expectile, step)
+        expectile = h.linear_schedule(self.cfg.expectile, step)
         v_value_loss = (
-            rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
+            rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
         ).sum(dim=0)
 
         total_loss = (
@@ -464,17 +479,12 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            # normalize between [0,1] to fit torchrl specification
-            priorities /= 1e4
-            priorities = priorities.clamp(max=1.0)
             replay_buffer.update_priority(
-                idxs[: self.cfg.batch_size],
-                priorities[: self.cfg.batch_size],
+                idxs[:num_slices],
+                priorities[:num_slices],
             )
-            if self.demo_batch_size > 0:
-                demo_buffer.update_priority(
-                    demo_idxs, priorities[self.cfg.batch_size :]
-                )
+            if demo_batch_size > 0:
+                demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
 
         # Update policy + target network
         _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
@@ -493,10 +503,12 @@ class TDMPC(nn.Module):
             "weighted_loss": float(weighted_loss.mean().item()),
             "grad_norm": float(grad_norm),
         }
-        for key in ["demo_batch_size", "expectile"]:
-            if hasattr(self, key):
-                metrics[key] = getattr(self, key)
+        # for key in ["demo_batch_size", "expectile"]:
+        #     if hasattr(self, key):
+        metrics["demo_batch_size"] = demo_batch_size
+        metrics["expectile"] = expectile
         metrics.update(value_info)
         metrics.update(pi_update_info)
 
+        self.step = step
         return metrics
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index ce43b293..f1a014aa 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -80,7 +80,7 @@ expectile: 0.9
 A_scaling: 3.0
 
 # offline->online
-offline_steps: ${train_steps}/2
+offline_steps: 25000 # ${train_steps}/2
 pretrained_model_path: ""
 balanced_sampling: true
 demo_schedule: 0.5
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 55c9c0f8..8ae05cda 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -19,6 +19,7 @@ from lerobot.common.logger import Logger
 from lerobot.common.tdmpc import TDMPC
 from lerobot.common.utils import set_seed
 from lerobot.scripts.eval import eval_policy
+from rl.torchrl.collectors.collectors import SyncDataCollector
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -29,8 +30,10 @@ def train(cfg: dict):
 
     env = make_env(cfg)
     policy = TDMPC(cfg)
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+    policy.step = 25000
+    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    # policy.step = 100000
     policy.load(ckpt_path)
 
     td_policy = TensorDictModule(
@@ -54,7 +57,7 @@ def train(cfg: dict):
         strict_length=False,
     )
 
-    # TODO(rcadene): use PrioritizedReplayBuffer
+    # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
     offline_buffer = SimxarmExperienceReplay(
         dataset_id,
         # download="force",
@@ -68,9 +71,22 @@ def train(cfg: dict):
     index = torch.arange(0, num_steps, 1)
     sampler.extend(index)
 
-    # offline_buffer._storage.device = torch.device("cuda")
-    # offline_buffer._storage._storage.to(torch.device("cuda"))
-    # TODO(rcadene): add online_buffer
+    if cfg.balanced_sampling:
+        online_sampler = PrioritizedSliceSampler(
+            max_capacity=100_000,
+            alpha=0.7,
+            beta=0.9,
+            num_slices=num_traj_per_batch,
+            strict_length=False,
+        )
+
+        online_buffer = TensorDictReplayBuffer(
+            storage=LazyMemmapStorage(100_000),
+            sampler=online_sampler,
+            # batch_size=3,
+            # pin_memory=False,
+            # prefetch=3,
+        )
 
     # Observation encoder
     # Dynamics predictor
@@ -81,59 +97,80 @@ def train(cfg: dict):
 
     L = Logger(cfg.log_dir, cfg)
 
-    episode_idx = 0
+    online_episode_idx = 0
     start_time = time.time()
     step = 0
     last_log_step = 0
     last_save_step = 0
 
+    # TODO(rcadene): remove
+    step = 25000
+
     while step < cfg.train_steps:
         is_offline = True
         num_updates = cfg.episode_length
         _step = step + num_updates
         rollout_metrics = {}
 
-        # if step >= cfg.offline_steps:
-        #     is_offline = False
+        if step >= cfg.offline_steps:
+            is_offline = False
 
-        #     # Collect trajectory
-        #     obs = env.reset()
-        #     episode = Episode(cfg, obs)
-        #     success = False
-        #     while not episode.done:
-        #         action = policy.act(obs, step=step, t0=episode.first)
-        #         obs, reward, done, info = env.step(action.cpu().numpy())
-        #         reward = reward_normalizer(reward)
-        #         mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
-        #         success = info.get('success', False)
-        #         episode += (obs, action, reward, done, mask, success)
-        #     assert len(episode) <= cfg.episode_length
-        #     buffer += episode
-        #     episode_idx += 1
-        #     rollout_metrics = {
-        #         'episode_reward': episode.cumulative_reward,
-        #         'episode_success': float(success),
-        #         'episode_length': len(episode)
-        #     }
-        #     num_updates = len(episode) * cfg.utd
-        #     _step = min(step + len(episode), cfg.train_steps)
+            # TODO: use SyncDataCollector for that?
+            rollout = env.rollout(
+                max_steps=cfg.episode_length,
+                policy=td_policy,
+            )
+            assert len(rollout) <= cfg.episode_length
+            rollout["episode"] = torch.tensor(
+                [online_episode_idx] * len(rollout), dtype=torch.int
+            )
+            online_buffer.extend(rollout)
+
+            # Collect trajectory
+            # obs = env.reset()
+            # episode = Episode(cfg, obs)
+            # success = False
+            # while not episode.done:
+            #     action = policy.act(obs, step=step, t0=episode.first)
+            #     obs, reward, done, info = env.step(action.cpu().numpy())
+            #     reward = reward_normalizer(reward)
+            #     mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
+            #     success = info.get('success', False)
+            #     episode += (obs, action, reward, done, mask, success)
+
+            ep_reward = rollout["next", "reward"].sum()
+            ep_success = rollout["next", "success"].any()
+
+            online_episode_idx += 1
+            rollout_metrics = {
+                # 'episode_reward': episode.cumulative_reward,
+                # 'episode_success': float(success),
+                # 'episode_length': len(episode)
+                "avg_reward": np.nanmean(ep_reward),
+                "pc_success": np.nanmean(ep_success) * 100,
+            }
+            num_updates = len(rollout) * cfg.utd
+            _step = min(step + len(rollout), cfg.train_steps)
 
         # Update model
         train_metrics = {}
         if is_offline:
             for i in range(num_updates):
                 train_metrics.update(policy.update(offline_buffer, step + i))
-        # else:
-        #     for i in range(num_updates):
-        #         train_metrics.update(
-        #             policy.update(buffer, step + i // cfg.utd,
-        #                          demo_buffer=offline_buffer if cfg.balanced_sampling else None)
-        #         )
+        else:
+            for i in range(num_updates):
+                train_metrics.update(
+                    policy.update(
+                        online_buffer,
+                        step + i // cfg.utd,
+                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
+                    )
+                )
 
         # Log training metrics
         env_step = int(_step * cfg.action_repeat)
         common_metrics = {
-            "episode": episode_idx,
+            "episode": online_episode_idx,
             "step": _step,
             "env_step": env_step,
             "total_time": time.time() - start_time,