fix bus error

This commit is contained in:
Remi Cadene 2024-03-01 14:20:55 +00:00
parent fa7f473142
commit b5a2f460ea
1 changed files with 11 additions and 6 deletions

View File

@ -198,7 +198,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
# load # load
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
episode_ids = dataset_dict.get_episode_idxs() episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0] num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0] total_frames = dataset_dict["action"].shape[0]
assert len( assert len(
@ -209,6 +209,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = get_goal_pose_body(goal_pos_angle) goal_body = get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
idx0 = 0 idx0 = 0
idxtd = 0 idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)): for episode_id in tqdm.tqdm(range(num_episodes)):
@ -218,10 +223,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
assert (episode_ids[idx0:idx1] == episode_id).all() assert (episode_ids[idx0:idx1] == episode_id).all()
image = torch.from_numpy(dataset_dict["img"][idx0:idx1]) image = imgs[idx0:idx1]
image = einops.rearrange(image, "b h w c -> b c h w")
state = torch.from_numpy(dataset_dict["state"][idx0:idx1]) state = states[idx0:idx1]
agent_pos = state[:, :2] agent_pos = state[:, :2]
block_pos = state[:, 2:4] block_pos = state[:, 2:4]
block_angle = state[:, 4] block_angle = state[:, 4]
@ -255,12 +259,13 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
# last step of demonstration is considered done # last step of demonstration is considered done
done[-1] = True done[-1] = True
print("before " + """episode = TensorDict(""")
episode = TensorDict( episode = TensorDict(
{ {
("observation", "image"): image[:-1], ("observation", "image"): image[:-1],
("observation", "state"): agent_pos[:-1], ("observation", "state"): agent_pos[:-1],
"action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1], "action": actions[idx0:idx1][:-1],
"episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1], "episode": episode_ids[idx0:idx1][:-1],
"frame_id": torch.arange(0, num_frames - 1, 1), "frame_id": torch.arange(0, num_frames - 1, 1),
("next", "observation", "image"): image[1:], ("next", "observation", "image"): image[1:],
("next", "observation", "state"): agent_pos[1:], ("next", "observation", "state"): agent_pos[1:],