From 093bb9eef19024c9a5ad4d3d986955a08d2025ab Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 25 Mar 2024 14:08:43 +0000
Subject: [PATCH] Keep the last frame when preparing datasets

---
 lerobot/common/datasets/aloha.py | 22 +++++++++++++---------
 lerobot/common/datasets/pusht.py | 26 ++++++++++++++------------
 2 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 7c0c9d44..27f2750b 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -154,16 +154,18 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
         state = torch.from_numpy(ep["/observations/qpos"][:])
         action = torch.from_numpy(ep["/action"][:])
 
+        # Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
+        # last.
         ep_td = TensorDict(
             {
-                ("observation", "state"): state[:-1],
-                "action": action[:-1],
-                "episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
-                "frame_id": torch.arange(0, ep_num_frames - 1, 1),
-                ("next", "observation", "state"): state[1:],
+                ("observation", "state"): state,
+                "action": action,
+                "episode": torch.tensor([ep_id] * ep_num_frames),
+                "frame_id": torch.arange(0, ep_num_frames, 1),
+                ("next", "observation", "state"): torch.cat([state[1:], state[-1].unsqueeze(0)]),
                 # TODO: compute reward and success
                 # ("next", "reward"): reward[1:],
-                ("next", "done"): done[1:],
+                ("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
                 # ("next", "success"): success[1:],
             },
-            batch_size=ep_num_frames - 1,
+            batch_size=ep_num_frames,
@@ -172,8 +174,10 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
         for cam in CAMERAS[self.dataset_id]:
             image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
             image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
-            ep_td["observation", "image", cam] = image[:-1]
-            ep_td["next", "observation", "image", cam] = image[1:]
+            ep_td["observation", "image", cam] = image
+            ep_td["next", "observation", "image", cam] = torch.cat(
+                [image[1:], image[-1].unsqueeze(0)]
+            )
 
         if ep_id == 0:
             # hack to initialize tensordict data structure to store episodes
diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index bcbb10b8..4bd617cf 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -192,19 +192,21 @@ class PushtExperienceReplay(AbstractExperienceReplay):
         # last step of demonstration is considered done
         done[-1] = True
 
+        # Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
+        # last.
         ep_td = TensorDict(
             {
-                ("observation", "image"): image[:-1],
-                ("observation", "state"): agent_pos[:-1],
-                "action": actions[idx0:idx1][:-1],
-                "episode": episode_ids[idx0:idx1][:-1],
-                "frame_id": torch.arange(0, num_frames - 1, 1),
-                ("next", "observation", "image"): image[1:],
-                ("next", "observation", "state"): agent_pos[1:],
-                # TODO: verify that reward and done are aligned with image and agent_pos
-                ("next", "reward"): reward[1:],
-                ("next", "done"): done[1:],
-                ("next", "success"): success[1:],
+                ("observation", "image"): image,
+                ("observation", "state"): agent_pos,
+                "action": actions[idx0:idx1],
+                "episode": episode_ids[idx0:idx1],
+                "frame_id": torch.arange(0, num_frames, 1),
+                ("next", "observation", "image"): torch.cat([image[1:], image[-1].unsqueeze(0)]),
+                ("next", "observation", "state"): torch.cat([agent_pos[1:], agent_pos[-1].unsqueeze(0)]),
+                # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
+                ("next", "reward"): torch.cat([reward[1:], reward[-1].unsqueeze(0)]),
+                ("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
+                ("next", "success"): torch.cat([success[1:], success[-1].unsqueeze(0)]),
             },
-            batch_size=num_frames - 1,
+            batch_size=num_frames,
         )