Add task field to frame_dict in ReplayBuffer and simplify save_episode calls
- Introduced a new "task" field in frame_dict to meet the requirements of LeRobotDataset. - Removed task_name parameter from save_episode calls for consistency.
This commit is contained in:
parent
f483931fc0
commit
5fbbc65869
|
@ -515,6 +515,9 @@ class ReplayBuffer:
|
||||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||||
|
|
||||||
|
# Add task field which is required by LeRobotDataset
|
||||||
|
frame_dict["task"] = task_name
|
||||||
|
|
||||||
# Add to the dataset's buffer
|
# Add to the dataset's buffer
|
||||||
lerobot_dataset.add_frame(frame_dict)
|
lerobot_dataset.add_frame(frame_dict)
|
||||||
|
@ -524,7 +527,7 @@ class ReplayBuffer:
|
||||||
|
|
||||||
# If we reached an episode boundary, call save_episode, reset counters
|
# If we reached an episode boundary, call save_episode, reset counters
|
||||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||||
lerobot_dataset.save_episode(task=task_name)
|
lerobot_dataset.save_episode()
|
||||||
episode_index += 1
|
episode_index += 1
|
||||||
frame_idx_in_episode = 0
|
frame_idx_in_episode = 0
|
||||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
||||||
|
@ -533,10 +536,9 @@ class ReplayBuffer:
|
||||||
|
|
||||||
# Save any remaining frames in the buffer
|
# Save any remaining frames in the buffer
|
||||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||||
lerobot_dataset.save_episode(task=task_name)
|
lerobot_dataset.save_episode()
|
||||||
|
|
||||||
lerobot_dataset.stop_image_writer()
|
lerobot_dataset.stop_image_writer()
|
||||||
lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
|
|
||||||
|
|
||||||
return lerobot_dataset
|
return lerobot_dataset
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue