Keep the last frame when preparing datasets
This commit is contained in:
parent
e41c420a96
commit
093bb9eef1
|
@ -154,16 +154,18 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
|
||||||
|
# Note: for the "next" key we take data[1:] and then append one more frame that is a copy of the
|
||||||
|
# last.
|
||||||
ep_td = TensorDict(
|
ep_td = TensorDict(
|
||||||
{
|
{
|
||||||
("observation", "state"): state[:-1],
|
("observation", "state"): state,
|
||||||
"action": action[:-1],
|
"action": action,
|
||||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
"episode": torch.tensor([ep_id] * ep_num_frames),
|
||||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
"frame_id": torch.arange(0, ep_num_frames, 1),
|
||||||
("next", "observation", "state"): state[1:],
|
("next", "observation", "state"): torch.cat([state[1:], state[-1].unsqueeze(0)]),
|
||||||
# TODO: compute reward and success
|
# TODO: compute reward and success
|
||||||
# ("next", "reward"): reward[1:],
|
# ("next", "reward"): reward[1:],
|
||||||
("next", "done"): done[1:],
|
("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
|
||||||
# ("next", "success"): success[1:],
|
# ("next", "success"): success[1:],
|
||||||
},
|
},
|
||||||
batch_size=ep_num_frames - 1,
|
batch_size=ep_num_frames - 1,
|
||||||
|
@ -172,8 +174,10 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||||
for cam in CAMERAS[self.dataset_id]:
|
for cam in CAMERAS[self.dataset_id]:
|
||||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||||
ep_td["observation", "image", cam] = image[:-1]
|
ep_td["observation", "image", cam] = image
|
||||||
ep_td["next", "observation", "image", cam] = image[1:]
|
ep_td["next", "observation", "image", cam] = torch.cat(
|
||||||
|
[image[1:], image[-1].unsqueeze(0)]
|
||||||
|
)
|
||||||
|
|
||||||
if ep_id == 0:
|
if ep_id == 0:
|
||||||
# hack to initialize tensordict data structure to store episodes
|
# hack to initialize tensordict data structure to store episodes
|
||||||
|
|
|
@ -192,19 +192,21 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
||||||
# last step of demonstration is considered done
|
# last step of demonstration is considered done
|
||||||
done[-1] = True
|
done[-1] = True
|
||||||
|
|
||||||
|
# Note: for the "next" key we take data[1:] and then append one more frame that is a copy of the
|
||||||
|
# last.
|
||||||
ep_td = TensorDict(
|
ep_td = TensorDict(
|
||||||
{
|
{
|
||||||
("observation", "image"): image[:-1],
|
("observation", "image"): image,
|
||||||
("observation", "state"): agent_pos[:-1],
|
("observation", "state"): agent_pos,
|
||||||
"action": actions[idx0:idx1][:-1],
|
"action": actions[idx0:idx1],
|
||||||
"episode": episode_ids[idx0:idx1][:-1],
|
"episode": episode_ids[idx0:idx1],
|
||||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
"frame_id": torch.arange(0, num_frames, 1),
|
||||||
("next", "observation", "image"): image[1:],
|
("next", "observation", "image"): torch.cat([image[1:], image[-1].unsqueeze(0)]),
|
||||||
("next", "observation", "state"): agent_pos[1:],
|
("next", "observation", "state"): torch.cat([agent_pos[1:], agent_pos[-1].unsqueeze(0)]),
|
||||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||||
("next", "reward"): reward[1:],
|
("next", "reward"): torch.cat([reward[1:], reward[-1].unsqueeze(0)]),
|
||||||
("next", "done"): done[1:],
|
("next", "done"): torch.cat([done[1:], done[-1].unsqueeze(0)]),
|
||||||
("next", "success"): success[1:],
|
("next", "success"): torch.cat([success[1:], success[-1].unsqueeze(0)]),
|
||||||
},
|
},
|
||||||
batch_size=num_frames - 1,
|
batch_size=num_frames - 1,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue