fix bus error
This commit is contained in:
parent
fa7f473142
commit
b5a2f460ea
|
@ -198,7 +198,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
# load
|
||||
dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = dataset_dict.get_episode_idxs()
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
assert len(
|
||||
|
@ -209,6 +209,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
idx0 = 0
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
|
@ -218,10 +223,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = torch.from_numpy(dataset_dict["img"][idx0:idx1])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w")
|
||||
image = imgs[idx0:idx1]
|
||||
|
||||
state = torch.from_numpy(dataset_dict["state"][idx0:idx1])
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
@ -255,12 +259,13 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
print("before " + """episode = TensorDict(""")
|
||||
episode = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
"action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1],
|
||||
"episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1],
|
||||
"action": actions[idx0:idx1][:-1],
|
||||
"episode": episode_ids[idx0:idx1][:-1],
|
||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||
("next", "observation", "image"): image[1:],
|
||||
("next", "observation", "state"): agent_pos[1:],
|
||||
|
|
Loading…
Reference in New Issue