Keep the last frame when preparing datasets

This commit is contained in:
Alexander Soare 2024-03-25 14:08:43 +00:00
parent e41c420a96
commit 093bb9eef1
2 changed files with 25 additions and 19 deletions

View File

@ -154,16 +154,18 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
state = torch.from_numpy(ep["/observations/qpos"][:]) state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:]) action = torch.from_numpy(ep["/action"][:])
# Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
# last.
ep_td = TensorDict( ep_td = TensorDict(
{ {
("observation", "state"): state[:-1], ("observation", "state"): state,
"action": action[:-1], "action": action,
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)), "episode": torch.tensor([ep_id] * ep_num_frames),
"frame_id": torch.arange(0, ep_num_frames - 1, 1), "frame_id": torch.arange(0, ep_num_frames, 1),
("next", "observation", "state"): state[1:], ("next", "observation", "state"): torch.cat([state[1:], state[-1].unsqueeze(0)]),
# TODO: compute reward and success # TODO: compute reward and success
# ("next", "reward"): reward[1:], # ("next", "reward"): reward[1:],
("next", "done"): done[1:], ("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
# ("next", "success"): success[1:], # ("next", "success"): success[1:],
}, },
batch_size=ep_num_frames - 1, batch_size=ep_num_frames - 1,
@ -172,8 +174,10 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
for cam in CAMERAS[self.dataset_id]: for cam in CAMERAS[self.dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous() image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_td["observation", "image", cam] = image[:-1] ep_td["observation", "image", cam] = image
ep_td["next", "observation", "image", cam] = image[1:] ep_td["next", "observation", "image", cam] = torch.cat(
[image[1:], image[-1].unsqueeze(0)]
)
if ep_id == 0: if ep_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes

View File

@ -192,19 +192,21 @@ class PushtExperienceReplay(AbstractExperienceReplay):
# last step of demonstration is considered done # last step of demonstration is considered done
done[-1] = True done[-1] = True
# Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
# last.
ep_td = TensorDict( ep_td = TensorDict(
{ {
("observation", "image"): image[:-1], ("observation", "image"): image,
("observation", "state"): agent_pos[:-1], ("observation", "state"): agent_pos,
"action": actions[idx0:idx1][:-1], "action": actions[idx0:idx1],
"episode": episode_ids[idx0:idx1][:-1], "episode": episode_ids[idx0:idx1],
"frame_id": torch.arange(0, num_frames - 1, 1), "frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): image[1:], ("next", "observation", "image"): torch.cat([image[1:], image[-1].unsqueeze(0)]),
("next", "observation", "state"): agent_pos[1:], ("next", "observation", "state"): torch.cat([agent_pos[1:], agent_pos[-1].unsqueeze(0)]),
# TODO: verify that reward and done are aligned with image and agent_pos # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
("next", "reward"): reward[1:], ("next", "reward"): torch.cat([reward[1:], reward[-1].unsqueeze(0)]),
("next", "done"): done[1:], ("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
("next", "success"): success[1:], ("next", "success"): torch.cat([success[1:], success[-1].unsqueeze(0)]),
}, },
batch_size=num_frames - 1, batch_size=num_frames - 1,
) )