Keep the last frame when preparing datasets

Alexander Soare 2024-03-25 14:08:43 +00:00
parent e41c420a96
commit 093bb9eef1
2 changed files with 25 additions and 19 deletions
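
The gist of the change: previously each episode dropped its final frame so that every stored frame had a "next" counterpart; after this commit all frames are kept, and the "next" entries are built by shifting the data one step and appending a copy of the terminal frame as its own successor. A minimal sketch of the pattern (plain torch with made-up shapes, not the repository's code):

import torch

# Hypothetical episode: 4 frames of 2-dim state.
state = torch.arange(8, dtype=torch.float32).reshape(4, 2)

# Before this commit: drop the last frame so every kept frame has a successor.
obs_old, next_old = state[:-1], state[1:]  # both have 3 frames

# After this commit: keep all 4 frames; the terminal frame's "next" is itself.
next_new = torch.cat([state[1:], state[-1].unsqueeze(0)])
assert next_new.shape == state.shape
assert torch.equal(next_new[-1], state[-1])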


@@ -154,16 +154,18 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
         state = torch.from_numpy(ep["/observations/qpos"][:])
         action = torch.from_numpy(ep["/action"][:])
+        # Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
+        # last.
         ep_td = TensorDict(
             {
-                ("observation", "state"): state[:-1],
-                "action": action[:-1],
-                "episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
-                "frame_id": torch.arange(0, ep_num_frames - 1, 1),
-                ("next", "observation", "state"): state[1:],
+                ("observation", "state"): state,
+                "action": action,
+                "episode": torch.tensor([ep_id] * ep_num_frames),
+                "frame_id": torch.arange(0, ep_num_frames, 1),
+                ("next", "observation", "state"): torch.cat([state[1:], state[-1].unsqueeze(0)]),
                 # TODO: compute reward and success
                 # ("next", "reward"): reward[1:],
-                ("next", "done"): done[1:],
+                ("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
                 # ("next", "success"): success[1:],
             },
             batch_size=ep_num_frames - 1,
         )
@@ -172,8 +174,10 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
         for cam in CAMERAS[self.dataset_id]:
             image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
             image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
-            ep_td["observation", "image", cam] = image[:-1]
-            ep_td["next", "observation", "image", cam] = image[1:]
+            ep_td["observation", "image", cam] = image
+            ep_td["next", "observation", "image", cam] = torch.cat(
+                [image[1:], image[-1].unsqueeze(0)]
+            )

         if ep_id == 0:
             # hack to initialize tensordict data structure to store episodes
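
Side note on sizing (my reading, not part of the diff): TensorDict checks that every stored tensor's leading dimension matches batch_size, so with the entries now holding all ep_num_frames frames, the batch_size=… - 1 lines still visible as context in both files look like they need the same update, unless they are adjusted outside the shown context. A toy example of the constraint, with a hypothetical state dimension:

import torch
from tensordict import TensorDict

num_frames = 5
state = torch.randn(num_frames, 14)  # 14 is a made-up qpos dimension

ep_td = TensorDict(
    {
        ("observation", "state"): state,
        # Shift by one and repeat the final frame so the leading dim stays num_frames.
        ("next", "observation", "state"): torch.cat([state[1:], state[-1].unsqueeze(0)]),
    },
    batch_size=[num_frames],  # batch_size=[num_frames - 1] would raise a shape error here
)
assert ep_td.batch_size == torch.Size([num_frames])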


@@ -192,19 +192,21 @@ class PushtExperienceReplay(AbstractExperienceReplay):
         # last step of demonstration is considered done
         done[-1] = True
+        # Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
+        # last.
         ep_td = TensorDict(
             {
-                ("observation", "image"): image[:-1],
-                ("observation", "state"): agent_pos[:-1],
-                "action": actions[idx0:idx1][:-1],
-                "episode": episode_ids[idx0:idx1][:-1],
-                "frame_id": torch.arange(0, num_frames - 1, 1),
-                ("next", "observation", "image"): image[1:],
-                ("next", "observation", "state"): agent_pos[1:],
-                # TODO: verify that reward and done are aligned with image and agent_pos
-                ("next", "reward"): reward[1:],
-                ("next", "done"): done[1:],
-                ("next", "success"): success[1:],
+                ("observation", "image"): image,
+                ("observation", "state"): agent_pos,
+                "action": actions[idx0:idx1],
+                "episode": episode_ids[idx0:idx1],
+                "frame_id": torch.arange(0, num_frames, 1),
+                ("next", "observation", "image"): torch.cat([image[1:], image[-1].unsqueeze(0)]),
+                ("next", "observation", "state"): torch.cat([agent_pos[1:], agent_pos[-1].unsqueeze(0)]),
+                # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
+                ("next", "reward"): torch.cat([reward[1:], reward[-1].unsqueeze(0)]),
+                ("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
+                ("next", "success"): torch.cat([success[1:], success[-1].unsqueeze(0)]),
             },
             batch_size=num_frames - 1,
         )
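
One alignment consequence worth spelling out (an observation, with toy values): because done[-1] is set to True before the padding, the appended copy of done[-1] keeps the terminal transition flagged as done instead of introducing a spurious non-terminal step:

import torch

num_frames = 4
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True  # last step of demonstration is considered done

next_done = torch.cat([done[1:], done[-1].unsqueeze(0)])
assert next_done.tolist() == [False, False, True, True]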