From 5fbbc65869227ba49aaaba87a659ce0998484b1e Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Mon, 24 Mar 2025 20:28:14 +0000 Subject: [PATCH] Add task field to frame_dict in ReplayBuffer and simplify save_episode calls - Introduced a new "task" field in frame_dict to meet the requirements of LeRobotDataset. - Removed task_name parameter from save_episode calls for consistency. --- lerobot/scripts/server/buffer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index e10ffbdf..0e25253f 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -515,6 +515,9 @@ class ReplayBuffer: frame_dict["action"] = self.actions[actual_idx].cpu() frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + + # Add task field which is required by LeRobotDataset + frame_dict["task"] = task_name # Add to the dataset's buffer lerobot_dataset.add_frame(frame_dict) @@ -524,7 +527,7 @@ class ReplayBuffer: # If we reached an episode boundary, call save_episode, reset counters if self.dones[actual_idx] or self.truncateds[actual_idx]: - lerobot_dataset.save_episode(task=task_name) + lerobot_dataset.save_episode() episode_index += 1 frame_idx_in_episode = 0 lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( @@ -533,10 +536,9 @@ class ReplayBuffer: # Save any remaining frames in the buffer if lerobot_dataset.episode_buffer["size"] > 0: - lerobot_dataset.save_episode(task=task_name) + lerobot_dataset.save_episode() lerobot_dataset.stop_image_writer() - lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False) return lerobot_dataset