diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 1bd63f6e..685084cd 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -25,8 +25,10 @@ def visualize_dataset_cli(cfg: dict): def cat_and_write_video(video_path, frames, fps): + # Expects images in [0, 1]. frames = torch.cat(frames) - assert frames.dtype == torch.uint8 + assert frames.max() <= 1 and frames.min() >= 0 + frames = (255 * frames).to(dtype=torch.uint8) frames = einops.rearrange(frames, "b c h w -> b h w c").numpy() imageio.mimsave(video_path, frames, fps=fps) @@ -59,7 +61,7 @@ def visualize_dataset(cfg: dict, out_dir=None): ep_td = offline_buffer.sample(1) ep_idx = ep_td["episode"][FIRST_FRAME].item() - # TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames + # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames no_more_frames = offline_buffer._sampler._sample_list.numel() == 0 new_episode = ep_idx != current_ep_idx @@ -69,7 +71,7 @@ def visualize_dataset(cfg: dict, out_dir=None): for im_key in offline_buffer.image_keys: if new_episode or no_more_frames: # append last observed frames (the ones after last action taken) - frames[im_key].append(ep_td[("next", *im_key)]) + frames[im_key].append(offline_buffer.transform(ep_td["next"])[im_key]) video_dir = Path(out_dir) / "visualize_dataset" video_dir.mkdir(parents=True, exist_ok=True) @@ -101,6 +103,9 @@ def visualize_dataset(cfg: dict, out_dir=None): logging.info("Ran out of frames") break + if current_ep_idx == NUM_EPISODES_TO_RENDER: + break + for thread in threads: thread.join()