Add task field to frame_dict in ReplayBuffer and simplify save_episode calls

- Introduced a new "task" field in frame_dict to meet the requirements of LeRobotDataset.
- Removed task_name parameter from save_episode calls for consistency.
This commit is contained in:
AdilZouitine 2025-03-24 20:28:14 +00:00
parent f483931fc0
commit 5fbbc65869
1 changed files with 5 additions and 3 deletions

View File

@ -515,6 +515,9 @@ class ReplayBuffer:
frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
# Add task field which is required by LeRobotDataset
frame_dict["task"] = task_name
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
@ -524,7 +527,7 @@ class ReplayBuffer:
# If we reached an episode boundary, call save_episode, reset counters
if self.dones[actual_idx] or self.truncateds[actual_idx]:
lerobot_dataset.save_episode(task=task_name)
lerobot_dataset.save_episode()
episode_index += 1
frame_idx_in_episode = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
@ -533,10 +536,9 @@ class ReplayBuffer:
# Save any remaining frames in the buffer
if lerobot_dataset.episode_buffer["size"] > 0:
lerobot_dataset.save_episode(task=task_name)
lerobot_dataset.save_episode()
lerobot_dataset.stop_image_writer()
lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
return lerobot_dataset