Add video decoding in dataset (WIP: issue with gray background)

This commit is contained in:
Cadene 2024-03-18 16:25:33 +00:00
parent 9954994a4b
commit 2bef00c317
1 changed files with 129 additions and 0 deletions

View File

@ -0,0 +1,129 @@
"""
usage: `python lerobot/scripts/convert_dataset_uint8_to_mp4.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
"""
import argparse
import shutil
from pathlib import Path
import torch
from tensordict import TensorDict
def convert_dataset_uint8_to_mp4(in_data_dir, out_data_dir, fps, overwrite_num_frames=None):
assert fps is not None and isinstance(fps, float)
# load full dataset as a tensor dict
in_td_data = TensorDict.load_memmap(in_data_dir)
out_data_dir = Path(out_data_dir)
# use 1 frame to know the specification of the dataset
# and copy it over `n` frames in the test artifact directory
out_rb_dir = out_data_dir / "replay_buffer"
if out_rb_dir.exists():
shutil.rmtree(out_rb_dir)
num_frames = len(in_td_data) if overwrite_num_frames is None else overwrite_num_frames
# del in_td_data["observation", "image"]
# del in_td_data["next", "observation", "image"]
out_td_data = in_td_data[0].memmap_().clone()
out_td_data["observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
out_td_data["observation", "frame", "timestamp"] = torch.zeros(1)
out_td_data["next", "observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
out_td_data["next", "observation", "frame", "timestamp"] = torch.zeros(1)
out_td_data = out_td_data.expand(num_frames)
out_td_data = out_td_data.memmap_like(out_rb_dir)
out_vid_dir = out_data_dir / "videos"
out_vid_dir.mkdir(parents=True, exist_ok=True)
video_id_to_path = {}
for key in out_td_data.keys(include_nested=True, leaves_only=True):
if in_td_data.get(key, None) is None:
continue
if overwrite_num_frames is None:
out_td_data[key].copy_(in_td_data[key].clone())
else:
out_td_data[key][:num_frames].copy_(in_td_data[key][:num_frames].clone())
for i in range(num_frames):
video_id = in_td_data["episode"][i]
frame_id = in_td_data["frame_id"][i]
out_td_data["observation", "frame", "video_id"][i] = video_id
out_td_data["observation", "frame", "timestamp"][i] = frame_id / fps
out_td_data["next", "observation", "frame", "video_id"][i] = video_id
out_td_data["next", "observation", "frame", "timestamp"][i] = (frame_id + 1) / fps
video_id = video_id.item()
if video_id not in video_id_to_path:
video_id_to_path[video_id] = f"videos/episode_{video_id}.mp4"
# copy the first `n` frames so that we have real data
# make sure everything has been properly written
out_td_data.lock_()
# copy the full statistics of dataset since it's pretty small
in_stats_path = Path(in_data_dir) / "stats.pth"
out_stats_path = Path(out_data_dir) / "stats.pth"
shutil.copy(in_stats_path, out_stats_path)
meta_data = {
"video_id_to_path": video_id_to_path,
}
torch.save(meta_data, out_data_dir / "meta_data.pth")
# def write_to_mp4():
# buffer = io.BytesIO()
# swriter = StreamWriter(buffer, format="mp4")
# device = "cuda"
# c,h,w = in_td_data[0]["observation", "image"].shape
# swriter.add_video_stream(
# frame_rate=fps,
# width=w,
# height=h,
# # frame_rate=30000 / 1001,
# format="yuv444p",
# encoder="h264_nvenc",
# encoder_format="yuv444p",
# hw_accel=device,
# )
# for i in range(num_frames):
# ep_id = in_td_data[i]["episode"]
# data = in_td_data[i]["observation", "image"]
# with swriter.open():
# t0 = time.monotonic()
# data = data.to(device)
# swriter.write_video_chunk(0, data)
# elapsed = time.monotonic() - t0
# size = buffer.tell()
# print(f"{elapsed=}")
# print(f"{size=}")
# buffer.seek(0)
# video = buffer.read()
# vid_path = out_vid_dir / f"episode_{ep_id}.mp4"
# with open(vid_path, 'wb+') as f:
# f.write(video)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create dataset")
parser.add_argument("--in-data-dir", type=str, help="Path to input data")
parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
parser.add_argument("--fps", type=float)
args = parser.parse_args()
convert_dataset_uint8_to_mp4(args.in_data_dir, args.out_data_dir, args.fps)