fix bus error
This commit is contained in:
parent
fa7f473142
commit
b5a2f460ea
|
@ -198,7 +198,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
# load
|
# load
|
||||||
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||||
|
|
||||||
episode_ids = dataset_dict.get_episode_idxs()
|
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||||
total_frames = dataset_dict["action"].shape[0]
|
total_frames = dataset_dict["action"].shape[0]
|
||||||
assert len(
|
assert len(
|
||||||
|
@ -209,6 +209,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||||
goal_body = get_goal_pose_body(goal_pos_angle)
|
goal_body = get_goal_pose_body(goal_pos_angle)
|
||||||
|
|
||||||
|
imgs = torch.from_numpy(dataset_dict["img"])
|
||||||
|
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||||
|
states = torch.from_numpy(dataset_dict["state"])
|
||||||
|
actions = torch.from_numpy(dataset_dict["action"])
|
||||||
|
|
||||||
idx0 = 0
|
idx0 = 0
|
||||||
idxtd = 0
|
idxtd = 0
|
||||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||||
|
@ -218,10 +223,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
|
|
||||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||||
|
|
||||||
image = torch.from_numpy(dataset_dict["img"][idx0:idx1])
|
image = imgs[idx0:idx1]
|
||||||
image = einops.rearrange(image, "b h w c -> b c h w")
|
|
||||||
|
|
||||||
state = torch.from_numpy(dataset_dict["state"][idx0:idx1])
|
state = states[idx0:idx1]
|
||||||
agent_pos = state[:, :2]
|
agent_pos = state[:, :2]
|
||||||
block_pos = state[:, 2:4]
|
block_pos = state[:, 2:4]
|
||||||
block_angle = state[:, 4]
|
block_angle = state[:, 4]
|
||||||
|
@ -255,12 +259,13 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
# last step of demonstration is considered done
|
# last step of demonstration is considered done
|
||||||
done[-1] = True
|
done[-1] = True
|
||||||
|
|
||||||
|
print("before " + """episode = TensorDict(""")
|
||||||
episode = TensorDict(
|
episode = TensorDict(
|
||||||
{
|
{
|
||||||
("observation", "image"): image[:-1],
|
("observation", "image"): image[:-1],
|
||||||
("observation", "state"): agent_pos[:-1],
|
("observation", "state"): agent_pos[:-1],
|
||||||
"action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1],
|
"action": actions[idx0:idx1][:-1],
|
||||||
"episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1],
|
"episode": episode_ids[idx0:idx1][:-1],
|
||||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||||
("next", "observation", "image"): image[1:],
|
("next", "observation", "image"): image[1:],
|
||||||
("next", "observation", "state"): agent_pos[1:],
|
("next", "observation", "state"): agent_pos[1:],
|
||||||
|
|
Loading…
Reference in New Issue