From 0fc94b81b37cc8ba80171f530adc50107448f7e7 Mon Sep 17 00:00:00 2001 From: Cadene Date: Mon, 18 Mar 2024 16:24:05 +0000 Subject: [PATCH] Add video decoding in dataset (WIP: issue with gray background) --- lerobot/common/datasets/abstract.py | 19 +- lerobot/common/datasets/factory.py | 29 ++- lerobot/common/datasets/utils.py | 24 +++ lerobot/scripts/visualize_dataset.py | 53 ++++- test.py | 296 ++++++--------------------- 5 files changed, 172 insertions(+), 249 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 34b33c2e..d3eb2763 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.envs.transforms.transforms import Compose +from lerobot.common.datasets.transforms import DecodeVideoTransform + class AbstractExperienceReplay(TensorDictReplayBuffer): def __init__( @@ -33,7 +35,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): self.dataset_id = dataset_id self.shuffle = shuffle self.root = root - storage = self._download_or_load_dataset() + storage, meta_data = self._download_or_load_dataset() + + if transform is not None and "video_id_to_path" in meta_data: + # hack to access video paths + assert isinstance(transform, Compose) + for tf in transform: + if isinstance(tf, DecodeVideoTransform): + tf.set_video_id_to_path(meta_data["video_id_to_path"]) super().__init__( storage=storage, @@ -99,7 +108,13 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")) else: self.data_dir = self.root / self.dataset_id - return TensorStorage(TensorDict.load_memmap(self.data_dir)) + + storage = TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer")) + # required to not send cuda frames to cpu by default + storage._storage.clear_device_() + + meta_data = torch.load(self.data_dir / "meta_data.pth") + return storage, meta_data def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 3f4772c4..96c8bd22 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -87,17 +87,30 @@ def make_offline_buffer( prefetch=prefetch if isinstance(prefetch, int) else None, ) - if cfg.policy.name == "tdmpc": - img_keys = [] - for key in offline_buffer.image_keys: - img_keys.append(("next", *key)) - img_keys += offline_buffer.image_keys - else: - img_keys = offline_buffer.image_keys + transforms = [] - transforms = [Prod(in_keys=img_keys, prod=1 / 255)] + # transforms = [ + # ViewSliceHorizonTransform(num_slices, cfg.policy.horizon), + # KeepFrames(positions=[0], in_keys=[("observation")]), + # DecodeVideoTransform( + # data_dir=offline_buffer.data_dir, + # device=cfg.device, + # frame_rate=None, + # in_keys=[("observation", "frame")], + # out_keys=[("observation", "frame", "data")], + # ), + # ] if normalize: + if cfg.policy.name == "tdmpc": + img_keys = [] + for key in offline_buffer.image_keys: + img_keys.append(("next", *key)) + img_keys += offline_buffer.image_keys + else: + img_keys = offline_buffer.image_keys + transforms.append(Prod(in_keys=img_keys, prod=1 / 255)) + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, 
min_max_from_spec stats = offline_buffer.compute_or_load_stats() diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 0ad43a65..ff5b1eb5 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -3,6 +3,7 @@ import zipfile from pathlib import Path import requests +import torch import tqdm @@ -28,3 +29,26 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return True else: return False + + +def yuv_to_rgb(frames): + assert frames.dtype == torch.uint8 + assert frames.ndim == 4 + assert frames.shape[1] == 3 + + frames = frames.cpu().to(torch.float) + y = frames[..., 0, :, :] + u = frames[..., 1, :, :] + v = frames[..., 2, :, :] + + y /= 255 + u = u / 255 - 0.5 + v = v / 255 - 0.5 + + r = y + 1.14 * v + g = y + -0.396 * u - 0.581 * v + b = y + 2.029 * u + + rgb = torch.stack([r, g, b], 1) + rgb = (rgb * 255).clamp(0, 255).to(torch.uint8) + return rgb diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 1bd63f6e..f14196f9 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -1,5 +1,8 @@ import logging +import shutil +import subprocess import threading +import time from pathlib import Path import einops @@ -26,9 +29,48 @@ def visualize_dataset_cli(cfg: dict): def cat_and_write_video(video_path, frames, fps): frames = torch.cat(frames) - assert frames.dtype == torch.uint8 + if frames.dtype != torch.uint8: + logging.warning(f"frames are expected to be uint8 to {frames.dtype}") + frames = frames.type(torch.uint8) + + _, _, h, w = frames.shape frames = einops.rearrange(frames, "b c h w -> b h w c").numpy() - imageio.mimsave(video_path, frames, fps=fps) + + img_dir = Path(video_path.split(".")[0]) + if img_dir.exists(): + shutil.rmtree(img_dir) + img_dir.mkdir(parents=True, exist_ok=True) + + for i in range(len(frames)): + imageio.imwrite(str(img_dir / f"frame_{i:04d}.png"), frames[i]) + + ffmpeg_command = [ + "ffmpeg", + "-r", + str(fps), + "-f", + "image2", + "-s", + f"{w}x{h}", + "-i", + str(img_dir / "frame_%04d.png"), + "-vcodec", + "libx264", + #'-vcodec', 'libx265', + #'-vcodec', 'libaom-av1', + "-crf", + "0", # Lossless option + "-pix_fmt", + "yuv420p", # Specify pixel format + video_path, + # video_path.replace(".mp4", ".mkv") + ] + subprocess.run(ffmpeg_command, check=True) + + time.sleep(0.1) + + # clean temporary image directory + # shutil.rmtree(img_dir) def visualize_dataset(cfg: dict, out_dir=None): @@ -61,7 +103,10 @@ def visualize_dataset(cfg: dict, out_dir=None): # TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames no_more_frames = offline_buffer._sampler._sample_list.numel() == 0 - new_episode = ep_idx != current_ep_idx + new_episode = ep_idx > current_ep_idx + + if ep_idx < current_ep_idx: + break if new_episode: logging.info(f"Visualizing episode {current_ep_idx}") @@ -71,7 +116,7 @@ def visualize_dataset(cfg: dict, out_dir=None): # append last observed frames (the ones after last action taken) frames[im_key].append(ep_td[("next", *im_key)]) - video_dir = Path(out_dir) / "visualize_dataset" + video_dir = Path(out_dir) / "videos" video_dir.mkdir(parents=True, exist_ok=True) if len(offline_buffer.image_keys) > 1: diff --git a/test.py b/test.py index da87a27e..bf8f877c 100644 --- a/test.py +++ b/test.py @@ -4,24 +4,21 @@ import subprocess from collections.abc import Callable from pathlib import Path -from typing import Sequence 
import einops import torch import torchaudio import torchrl from matplotlib import pyplot as plt -from tensordict import TensorDict, TensorDictBase -from tensordict.nn import dispatch -from tensordict.utils import NestedKey -from torchaudio.io import StreamReader +from tensordict import TensorDict +from torchaudio.utils import ffmpeg_utils from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer -from torchrl.envs.transforms import Transform from torchrl.envs.transforms.transforms import Compose +from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform from lerobot.common.utils import set_seed NUM_STATE_CHANNELS = 12 @@ -42,15 +39,32 @@ def yuv_to_rgb(frames): u = u / 255 - 0.5 v = v / 255 - 0.5 - r = y + 1.14 * v - g = y + -0.396 * u - 0.581 * v - b = y + 2.029 * u + r = y + 1.13983 * v + g = y + -0.39465 * u - 0.58060 * v + b = y + 2.03211 * u rgb = torch.stack([r, g, b], 1) rgb = (rgb * 255).clamp(0, 255).to(torch.uint8) return rgb +def yuv_to_rgb_cv2(frames, return_hwc=True): + assert frames.dtype == torch.uint8 + assert frames.ndim == 4 + assert frames.shape[1] == 3 + frames = frames.cpu() + import cv2 + + frames = einops.rearrange(frames, "b c h w -> b h w c") + frames = frames.numpy() + frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames] + frames = [torch.from_numpy(frame) for frame in frames] + frames = torch.stack(frames) + if not return_hwc: + frames = einops.rearrange(frames, "b h w c -> b c h w") + return frames + + def count_frames(video_path): try: # Construct the ffprobe command to get the number of frames @@ -131,211 +145,6 @@ def get_frame_timestamps(frame_rate, num_frames): # return td -class ViewSliceHorizonTransform(Transform): - invertible = False - - def __init__(self, num_slices, horizon): - super().__init__() - self.num_slices = num_slices - self.horizon = horizon - - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td: TensorDictBase) -> TensorDictBase: - td = td.view(self.num_slices, self.horizon) - return td - - -class KeepFrames(Transform): - invertible = False - - def __init__( - self, - positions, - in_keys: Sequence[NestedKey], - out_keys: Sequence[NestedKey] = None, - ): - if isinstance(positions, list): - assert isinstance(positions[0], int) - # TODO(rcadene)L add support for `isinstance(positions, int)`? 
- - self.positions = positions - if out_keys is None: - out_keys = in_keys - super().__init__(in_keys=in_keys, out_keys=out_keys) - - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td: TensorDictBase) -> TensorDictBase: - # we need set batch_size=[] before assigning a different shape to td[outkey] - td.batch_size = [] - - for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: - continue - td[outkey] = td[inkey][:, self.positions] - return td - - -class DecodeVideoTransform(Transform): - invertible = False - - def __init__( - self, - device="cpu", - # format options are None=yuv420p (usually), rgb24, bgr24, etc. - format: str | None = None, - frame_rate: int | None = None, - width: int | None = None, - height: int | None = None, - in_keys: Sequence[NestedKey] = None, - out_keys: Sequence[NestedKey] = None, - in_keys_inv: Sequence[NestedKey] | None = None, - out_keys_inv: Sequence[NestedKey] | None = None, - ): - self.device = device - self.format = format - self.frame_rate = frame_rate - self.width = width - self.height = height - self.video_id_to_path = None - if out_keys is None: - out_keys = in_keys - if in_keys_inv is None: - in_keys_inv = out_keys - if out_keys_inv is None: - out_keys_inv = in_keys - super().__init__( - in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv - ) - - def set_video_id_to_path(self, video_id_to_path): - self.video_id_to_path = video_id_to_path - - def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase: - # _reset is called once when the environment reset to normalize the first observation - tensordict_reset = self._call(tensordict_reset) - return tensordict_reset - - @dispatch(source="in_keys", dest="out_keys") - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - return self._call(tensordict) - - def _call(self, td: TensorDictBase) -> TensorDictBase: - assert ( - self.video_id_to_path is not None - ), "Setting a video_id_to_path dictionary with `self.set_video_id_to_path(video_id_to_path)` is required." - - for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False): - # TODO(rcadene): don't know how to do `inkey not in td` - if td.get(inkey, None) is None: - continue - - bsize = len(td[inkey]) # num episodes in the batch - b_frames = [] - for i in range(bsize): - assert ( - td["observation", "frame", "video_id"].ndim == 2 - ), "We expect 2 dims. 
Respectively, number of episodes in the batch and number of observations" - - ep_video_ids = td[inkey]["video_id"][i] - timestamps = td[inkey]["timestamp"][i] - frame_ids = td["frame_id"][i] - - unique_video_id = (ep_video_ids.min() == ep_video_ids.max()).item() - assert unique_video_id - - is_ascending = torch.all(timestamps[:-1] <= timestamps[1:]).item() - assert is_ascending - - is_contiguous = ((frame_ids[1:] - frame_ids[:-1]) == 1).all().item() - assert is_contiguous - - FIRST_FRAME = 0 # noqa: N806 - video_id = ep_video_ids[FIRST_FRAME].item() - video_path = self.video_id_to_path[video_id] - first_frame_ts = timestamps[FIRST_FRAME].item() - num_contiguous_frames = len(timestamps) - - filter_desc = [] - video_stream_kwgs = { - "frames_per_chunk": num_contiguous_frames, - "buffer_chunk_size": num_contiguous_frames, - } - - # choice of decoder - if self.device == "cuda": - video_stream_kwgs["hw_accel"] = "cuda" - video_stream_kwgs["decoder"] = "h264_cuvid" - else: - video_stream_kwgs["decoder"] = "h264" - - # resize - resize_width = self.width is not None - resize_height = self.height is not None - if resize_width or resize_height: - if self.device == "cuda": - assert resize_width and resize_height - video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"} - else: - scales = [] - if resize_width: - scales.append(f"width={self.width}") - if resize_height: - scales.append(f"height={self.height}") - filter_desc.append(f"scale={':'.join(scales)}") - - # choice of format - if self.format is not None: - if self.device == "cuda": - # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp - raise NotImplementedError() - # filter_desc = f"scale=format={self.format}" - # filter_desc = f"scale_cuda=format={self.format}" - # filter_desc = f"scale_npp=format={self.format}" - else: - filter_desc.append(f"format=pix_fmts={self.format}") - - # choice of frame rate - if self.frame_rate is not None: - filter_desc.append(f"fps={self.frame_rate}") - - if len(filter_desc) > 0: - video_stream_kwgs["filter_desc"] = ",".join(filter_desc) - - # create a stream and load a certain number of frame at a certain frame rate - # TODO(rcadene): make sure it's the most optimal way to do it - s = StreamReader(video_path) - s.seek(first_frame_ts) - s.add_video_stream(**video_stream_kwgs) - s.fill_buffer() - (frames,) = s.pop_chunks() - - b_frames.append(frames) - - td[outkey] = torch.stack(b_frames) - - if self.device == "cuda": - # make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu - assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda" - return td - - class VideoExperienceReplay(TensorDictReplayBuffer): def __init__( self, @@ -349,7 +158,7 @@ class VideoExperienceReplay(TensorDictReplayBuffer): writer: Writer = None, transform: "torchrl.envs.Transform" = None, ): - self.data_dir = root / "2024_03_17_test_dataset" + self.data_dir = root self.rb_dir = self.data_dir / "replay_buffer" storage, meta_data = self._load_or_download() @@ -454,7 +263,18 @@ if __name__ == "__main__": import tqdm + print("FFmpeg Library versions:") + for k, ver in ffmpeg_utils.get_versions().items(): + print(f" {k}:\t{'.'.join(str(v) for v in ver)}") + + print("Available NVDEC Decoders:") + for k in ffmpeg_utils.get_video_decoders().keys(): # noqa: SIM118 + if "cuvid" in k: + print(f" - {k}") + def create_replay_buffer(device): + data_dir = Path("tmp/2024_03_17_data_video/pusht") + num_slices = 1 horizon = 2 batch_size = num_slices * 
horizon @@ -470,6 +290,7 @@ if __name__ == "__main__": ViewSliceHorizonTransform(num_slices, horizon), KeepFrames(positions=[0], in_keys=[("observation")]), DecodeVideoTransform( + data_dir=data_dir, device=device, frame_rate=None, in_keys=[("observation", "frame")], @@ -478,7 +299,7 @@ if __name__ == "__main__": ] replay_buffer = VideoExperienceReplay( - root=Path("tmp"), + root=data_dir, batch_size=batch_size, # prefetch=4, transform=Compose(*transforms), @@ -489,52 +310,57 @@ if __name__ == "__main__": def test_time(): replay_buffer = create_replay_buffer(device="cuda") - start = time.time() + start = time.monotonic() for _ in tqdm.tqdm(range(2)): # include_info=False is required to not have a batch_size mismatch error with the truncated key (2,8) != (16, 1) replay_buffer.sample(include_info=False) torch.cuda.synchronize() - print(time.time() - start) + print(time.monotonic() - start) - start = time.time() + start = time.monotonic() for _ in tqdm.tqdm(range(10)): replay_buffer.sample(include_info=False) torch.cuda.synchronize() - print(time.time() - start) + print(time.monotonic() - start) - def test_plot(): + def test_plot(seed=1337): rb_cuda = create_replay_buffer(device="cuda") - rb_cpu = create_replay_buffer(device="cpu") + rb_cpu = create_replay_buffer(device="cuda") n_rows = 2 # len(replay_buffer) - fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0]) + fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0]) for i in range(n_rows): - set_seed(1337 + i) + set_seed(seed + i) batch_cpu = rb_cpu.sample(include_info=False) - print(batch_cpu["frame_id"]) + print("frame_ids cpu", batch_cpu["frame_id"].tolist()) + print("episode cpu", batch_cpu["episode"].tolist()) + print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist()) frames = batch_cpu["observation", "frame", "data"] - frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") - frames = yuv_to_rgb(frames) - frames = einops.rearrange(frames, "bt c h w -> bt h w c") - + frames = yuv_to_rgb(frames, return_hwc=True) assert frames.shape[0] == 1 axes[i][0].imshow(frames[0]) - set_seed(1337 + i) + set_seed(seed + i) batch_cuda = rb_cuda.sample(include_info=False) - print(batch_cuda["frame_id"]) + print("frame_ids cuda", batch_cuda["frame_id"].tolist()) + print("episode cuda", batch_cuda["episode"].tolist()) + print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist()) frames = batch_cuda["observation", "frame", "data"] - frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") - frames = yuv_to_rgb(frames) - frames = einops.rearrange(frames, "bt c h w -> bt h w c") - + frames = yuv_to_rgb(frames, return_hwc=True) assert frames.shape[0] == 1 axes[i][1].imshow(frames[0]) + frames = batch_cuda["observation", "image"].type(torch.uint8) + frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") + frames = einops.rearrange(frames, "bt c h w -> bt h w c") + assert frames.shape[0] == 1 + axes[i][2].imshow(frames[0]) + axes[0][0].set_title("Software decoder") axes[0][1].set_title("HW decoder") + axes[0][2].set_title("uint8") plt.setp(axes, xticks=[], yticks=[]) plt.tight_layout() fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)
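
Notes / repro sketches (illustrative, not part of the diff above):

How the new pieces fit together outside of VideoExperienceReplay. This is a minimal sketch assuming a dataset already converted to the new layout, i.e. a directory containing a replay_buffer/ memmap plus a meta_data.pth holding "video_id_to_path"; the pusht path below is the one hard-coded in test.py and is only an example.

    from pathlib import Path

    import torch

    from lerobot.common.datasets.transforms import DecodeVideoTransform

    data_dir = Path("tmp/2024_03_17_data_video/pusht")
    meta_data = torch.load(data_dir / "meta_data.pth")

    decode = DecodeVideoTransform(
        data_dir=data_dir,
        device="cuda",
        frame_rate=None,
        in_keys=[("observation", "frame")],
        out_keys=[("observation", "frame", "data")],
    )
    # same hack as in AbstractExperienceReplay.__init__: hand the transform the
    # video_id -> path mapping so it can open the right file for each episode
    decode.set_video_id_to_path(meta_data["video_id_to_path"])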
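
To chase the gray-background issue independently of the replay buffer and sampler, the decode path of DecodeVideoTransform can be reproduced directly with torchaudio's StreamReader. Sketch below; the video path and start timestamp are placeholders, and the CUDA branch assumes an FFmpeg build with NVDEC (h264_cuvid), as listed by the prints at the top of test.py's __main__.

    import torch
    from torchaudio.io import StreamReader

    video_path = "tmp/2024_03_17_data_video/pusht/videos/episode_0000.mp4"  # placeholder path
    first_frame_ts = 0.0   # placeholder timestamp of the first wanted frame
    num_frames = 2         # horizon used in test.py
    device = "cuda" if torch.cuda.is_available() else "cpu"

    s = StreamReader(video_path)
    s.seek(first_frame_ts)
    if device == "cuda":
        # hardware decode; the returned chunk stays on the GPU
        s.add_video_stream(
            frames_per_chunk=num_frames,
            buffer_chunk_size=num_frames,
            decoder="h264_cuvid",
            hw_accel="cuda",
        )
    else:
        s.add_video_stream(
            frames_per_chunk=num_frames,
            buffer_chunk_size=num_frames,
            decoder="h264",
        )
    s.fill_buffer()
    (frames,) = s.pop_chunks()
    # yuv_to_rgb in test.py expects a (T, 3, H, W) uint8 YUV tensor
    print(frames.shape, frames.dtype, frames.device)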
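
One thing worth ruling out for the gray look is the YUV -> RGB conversion itself. The check below is hypothetical (not in the patch): it runs the same decoded chunk through the hand-written matrix (yuv_to_rgb in test.py) and through OpenCV (yuv_to_rgb_cv2), then prints how far apart they are. If both already look washed out, the culprit is more likely the pixel format or value range coming out of the decoder than the conversion matrix.

    # `frames` is the (T, 3, H, W) uint8 chunk from the snippet above;
    # yuv_to_rgb / yuv_to_rgb_cv2 are the helpers defined in test.py.
    rgb_manual = yuv_to_rgb(frames.cpu())
    rgb_cv2 = yuv_to_rgb_cv2(frames.cpu(), return_hwc=False)
    diff = (rgb_manual.float() - rgb_cv2.float()).abs()
    print("max abs diff:", diff.max().item(), "mean abs diff:", diff.mean().item())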