From b5a2f460ea2b6c60f7a58f9b910a4dbf13af3c19 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Fri, 1 Mar 2024 14:20:55 +0000 Subject: [PATCH] fix bus error --- lerobot/common/datasets/pusht.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 76f4c6cd..3fb0b20e 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -198,7 +198,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer): # load dataset_dict = ReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action']) - episode_ids = dataset_dict.get_episode_idxs() + episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) num_episodes = dataset_dict.meta["episode_ends"].shape[0] total_frames = dataset_dict["action"].shape[0] assert len( @@ -209,6 +209,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer): goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) goal_body = get_goal_pose_body(goal_pos_angle) + imgs = torch.from_numpy(dataset_dict["img"]) + imgs = einops.rearrange(imgs, "b h w c -> b c h w") + states = torch.from_numpy(dataset_dict["state"]) + actions = torch.from_numpy(dataset_dict["action"]) + idx0 = 0 idxtd = 0 for episode_id in tqdm.tqdm(range(num_episodes)): @@ -218,10 +223,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer): assert (episode_ids[idx0:idx1] == episode_id).all() - image = torch.from_numpy(dataset_dict["img"][idx0:idx1]) - image = einops.rearrange(image, "b h w c -> b c h w") + image = imgs[idx0:idx1] - state = torch.from_numpy(dataset_dict["state"][idx0:idx1]) + state = states[idx0:idx1] agent_pos = state[:, :2] block_pos = state[:, 2:4] block_angle = state[:, 4] @@ -255,12 +259,13 @@ class PushtExperienceReplay(TensorDictReplayBuffer): # last step of demonstration is considered done done[-1] = True + print("before " + """episode = TensorDict(""") episode = TensorDict( { ("observation", "image"): image[:-1], ("observation", "state"): agent_pos[:-1], - "action": torch.from_numpy(dataset_dict["action"][idx0:idx1])[:-1], - "episode": torch.from_numpy(episode_ids[idx0:idx1])[:-1], + "action": actions[idx0:idx1][:-1], + "episode": episode_ids[idx0:idx1][:-1], "frame_id": torch.arange(0, num_frames - 1, 1), ("next", "observation", "image"): image[1:], ("next", "observation", "state"): agent_pos[1:],