Merge pull request #48 from alexander-soare/fix_visualization
Fix normalization of last frame and data type in visualization
This commit is contained in:
commit
e41c420a96
|
@ -25,8 +25,10 @@ def visualize_dataset_cli(cfg: dict):
|
||||||
|
|
||||||
|
|
||||||
def cat_and_write_video(video_path, frames, fps):
|
def cat_and_write_video(video_path, frames, fps):
|
||||||
|
# Expects images in [0, 1].
|
||||||
frames = torch.cat(frames)
|
frames = torch.cat(frames)
|
||||||
assert frames.dtype == torch.uint8
|
assert frames.max() <= 1 and frames.min() >= 0
|
||||||
|
frames = (255 * frames).to(dtype=torch.uint8)
|
||||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||||
imageio.mimsave(video_path, frames, fps=fps)
|
imageio.mimsave(video_path, frames, fps=fps)
|
||||||
|
|
||||||
|
@ -59,7 +61,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||||
ep_td = offline_buffer.sample(1)
|
ep_td = offline_buffer.sample(1)
|
||||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||||
|
|
||||||
# TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||||
new_episode = ep_idx != current_ep_idx
|
new_episode = ep_idx != current_ep_idx
|
||||||
|
|
||||||
|
@ -69,7 +71,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||||
for im_key in offline_buffer.image_keys:
|
for im_key in offline_buffer.image_keys:
|
||||||
if new_episode or no_more_frames:
|
if new_episode or no_more_frames:
|
||||||
# append last observed frames (the ones after last action taken)
|
# append last observed frames (the ones after last action taken)
|
||||||
frames[im_key].append(ep_td[("next", *im_key)])
|
frames[im_key].append(offline_buffer.transform(ep_td["next"])[im_key])
|
||||||
|
|
||||||
video_dir = Path(out_dir) / "visualize_dataset"
|
video_dir = Path(out_dir) / "visualize_dataset"
|
||||||
video_dir.mkdir(parents=True, exist_ok=True)
|
video_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -101,6 +103,9 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||||
logging.info("Ran out of frames")
|
logging.info("Ran out of frames")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if current_ep_idx == NUM_EPISODES_TO_RENDER:
|
||||||
|
break
|
||||||
|
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue