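"""Visualize the episodes of a LeRobot dataset by rendering their camera frames to mp4 videos.

One mp4 is written per episode and per camera under `out_dir`. Video encoding is offloaded to
background threads so the next episode can be rendered in the meantime.
"""
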
import logging
import threading
from pathlib import Path

import einops
import hydra
import imageio
import torch

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
from lerobot.common.utils.utils import init_logging

NUM_EPISODES_TO_RENDER = 50
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0


@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict):
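    """Hydra entry point: forwards the config and hydra's runtime output directory to `visualize_dataset`."""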
    visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)


def cat_and_write_video(video_path, frames, fps):
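    """Concatenate a list of frame batches and write them to `video_path` as an mp4.

    Expects channel-first float32 frames in [0, 1]; converts them to channel-last uint8 before encoding.
    """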
    frames = torch.cat(frames)

    # Expects channel-first float32 images in [0, 1].
    frame = frames[0]
    if frame.ndim == 4:
        raise NotImplementedError("We currently don't support multiple timestamps.")
    c, h, w = frame.shape
    assert c < h and c < w, f"expect channel-first images, but instead {frame.shape}"

    # sanity check that images are float32 in range [0, 1]
    assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
    assert frame.max() <= 1, f"expect pixels lower than or equal to 1, but instead {frame.max()=}"
    assert frame.min() >= 0, f"expect pixels greater than or equal to 0, but instead {frame.min()=}"

    # convert to channel-last uint8 in [0, 255]
    frames = einops.rearrange(frames, "b c h w -> b h w c")
    frames = (frames * 255).type(torch.uint8)
    imageio.mimsave(video_path, frames.numpy(), fps=fps)


def visualize_dataset(cfg: dict, out_dir=None):
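    """Load the dataset described by `cfg` (without normalization) and render its episodes to videos in `out_dir`."""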
    if out_dir is None:
        raise NotImplementedError("out_dir must be provided.")

    init_logging()
    log_output_dir(out_dir)

    logging.info("make_dataset")
    dataset = make_dataset(
        cfg,
        # remove all transformations such as rescaling images from [0, 255] to [0, 1] or normalization
        normalize=False,
    )

    logging.info("Start rendering episodes from the offline buffer")
    video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
    for video_path in video_paths:
        logging.info(video_path)
    return video_paths


def render_dataset(dataset, out_dir, max_num_episodes):
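    """Iterate over the dataset frame by frame, group frames per episode and per camera, and write each group to an mp4.

    Returns the list of paths of the written videos.
    """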
    out_dir = Path(out_dir)
    video_paths = []
    threads = []

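    # Iterate one frame at a time and in storage order (batch_size=1, shuffle=False) so that
    # consecutive items belong to the same episode until its last index is reached.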
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=1,
        shuffle=False,
    )
    dl_iter = iter(dataloader)

    for ep_id in range(min(max_num_episodes, dataset.num_episodes)):
        logging.info(f"Rendering episode {ep_id}")

        frames = {}
        end_of_episode = False
        while not end_of_episode:
            item = next(dl_iter)

            for im_key in dataset.image_keys:
                # on the first frame of the episode, initialize the list of frames for this camera
                if im_key not in frames:
                    frames[im_key] = []
                # add the current frame to the list of frames to render
                frames[im_key].append(item[im_key])

            # the episode ends when the frame's global index reaches the last index of this episode
            end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1

        out_dir.mkdir(parents=True, exist_ok=True)
        for im_key in dataset.image_keys:
            # with multiple cameras, include the camera name in the video filename
            if len(dataset.image_keys) > 1:
                im_name = im_key.replace("observation.images.", "")
                video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
            else:
                video_path = out_dir / f"episode_{ep_id}.mp4"
            video_paths.append(video_path)

            # write the video in a background thread so the next episode can be rendered meanwhile
            thread = threading.Thread(
                target=cat_and_write_video,
                args=(str(video_path), frames[im_key], dataset.fps),
            )
            thread.start()
            threads.append(thread)

    for thread in threads:
        thread.join()

    logging.info("End of visualize_dataset")
    return video_paths


if __name__ == "__main__":
    visualize_dataset_cli()