Add task field to frame_dict in ReplayBuffer and simplify save_episode calls
- Introduced a new "task" field in frame_dict to meet the requirements of LeRobotDataset. - Removed task_name parameter from save_episode calls for consistency.
This commit is contained in:
parent
f483931fc0
commit
5fbbc65869
|
@ -515,6 +515,9 @@ class ReplayBuffer:
|
|||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
|
||||
# Add task field which is required by LeRobotDataset
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict)
|
||||
|
@ -524,7 +527,7 @@ class ReplayBuffer:
|
|||
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||
lerobot_dataset.save_episode(task=task_name)
|
||||
lerobot_dataset.save_episode()
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
||||
|
@ -533,10 +536,9 @@ class ReplayBuffer:
|
|||
|
||||
# Save any remaining frames in the buffer
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode(task=task_name)
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
lerobot_dataset.stop_image_writer()
|
||||
lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
|
Loading…
Reference in New Issue