Keep the last frame when preparing datasets
This commit is contained in:
parent
e41c420a96
commit
093bb9eef1
|
@ -154,16 +154,18 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
|||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
# Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
|
||||
# last.
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "state"): state[:-1],
|
||||
"action": action[:-1],
|
||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
||||
("next", "observation", "state"): state[1:],
|
||||
("observation", "state"): state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([ep_id] * ep_num_frames),
|
||||
"frame_id": torch.arange(0, ep_num_frames, 1),
|
||||
("next", "observation", "state"): torch.cat([state[1:], state[-1].unsqueeze(0)]),
|
||||
# TODO: compute reward and success
|
||||
# ("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
|
||||
# ("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=ep_num_frames - 1,
|
||||
|
@ -172,8 +174,10 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
|||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_td["observation", "image", cam] = image[:-1]
|
||||
ep_td["next", "observation", "image", cam] = image[1:]
|
||||
ep_td["observation", "image", cam] = image
|
||||
ep_td["next", "observation", "image", cam] = torch.cat(
|
||||
[image[1:], image[-1].unsqueeze(0)]
|
||||
)
|
||||
|
||||
if ep_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
|
|
|
@ -192,19 +192,21 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
|||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
# Note: for the "next" key we take data[1:] and then append on one more frame as a copy of the
|
||||
# last.
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
"action": actions[idx0:idx1][:-1],
|
||||
"episode": episode_ids[idx0:idx1][:-1],
|
||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||
("next", "observation", "image"): image[1:],
|
||||
("next", "observation", "state"): agent_pos[1:],
|
||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
("next", "success"): success[1:],
|
||||
("observation", "image"): image,
|
||||
("observation", "state"): agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode": episode_ids[idx0:idx1],
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
("next", "observation", "image"): torch.cat([image[1:], image[-1].unsqueeze(0)]),
|
||||
("next", "observation", "state"): torch.cat([agent_pos[1:], agent_pos[-1].unsqueeze(0)]),
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): torch.cat([reward[1:], reward[-1].unsqueeze(0)]),
|
||||
("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
|
||||
("next", "success"): torch.cat([success[1:], success[-1].unsqueeze(0)]),
|
||||
},
|
||||
batch_size=num_frames - 1,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue