Add video decoding in dataset (WIP: issue with gray background)
This commit is contained in:
parent d32a279435
commit 0fc94b81b3
@@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+from torchrl.envs.transforms.transforms import Compose

+from lerobot.common.datasets.transforms import DecodeVideoTransform


 class AbstractExperienceReplay(TensorDictReplayBuffer):
     def __init__(
@@ -33,7 +35,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         self.dataset_id = dataset_id
         self.shuffle = shuffle
         self.root = root
-        storage = self._download_or_load_dataset()
+        storage, meta_data = self._download_or_load_dataset()
+
+        if transform is not None and "video_id_to_path" in meta_data:
+            # hack to access video paths
+            assert isinstance(transform, Compose)
+            for tf in transform:
+                if isinstance(tf, DecodeVideoTransform):
+                    tf.set_video_id_to_path(meta_data["video_id_to_path"])

         super().__init__(
             storage=storage,
@@ -99,7 +108,13 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
             self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset"))
         else:
            self.data_dir = self.root / self.dataset_id
-        return TensorStorage(TensorDict.load_memmap(self.data_dir))
+
+        storage = TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
+        # required to not send cuda frames to cpu by default
+        storage._storage.clear_device_()
+
+        meta_data = torch.load(self.data_dir / "meta_data.pth")
+        return storage, meta_data

     def _compute_stats(self, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
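
Note on the new return value: `meta_data.pth` is read with `torch.load` and is expected to expose at least the `video_id_to_path` mapping consumed in the constructor above. A minimal sketch of producing such a file (the exact schema beyond `video_id_to_path`, and the paths, are assumptions):

    from pathlib import Path

    import torch

    # hypothetical mapping from each episode's video_id to its encoded file
    meta_data = {
        "video_id_to_path": {0: Path("videos/episode_0.mp4"), 1: Path("videos/episode_1.mp4")},
    }
    torch.save(meta_data, "data_dir/meta_data.pth")
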
@@ -87,6 +87,21 @@ def make_offline_buffer(
         prefetch=prefetch if isinstance(prefetch, int) else None,
     )

+    transforms = []
+
+    # transforms = [
+    #     ViewSliceHorizonTransform(num_slices, cfg.policy.horizon),
+    #     KeepFrames(positions=[0], in_keys=[("observation")]),
+    #     DecodeVideoTransform(
+    #         data_dir=offline_buffer.data_dir,
+    #         device=cfg.device,
+    #         frame_rate=None,
+    #         in_keys=[("observation", "frame")],
+    #         out_keys=[("observation", "frame", "data")],
+    #     ),
+    # ]
+
     if normalize:
         if cfg.policy.name == "tdmpc":
             img_keys = []
             for key in offline_buffer.image_keys:
@@ -94,10 +109,8 @@ def make_offline_buffer(
             img_keys += offline_buffer.image_keys
         else:
             img_keys = offline_buffer.image_keys
+        transforms.append(Prod(in_keys=img_keys, prod=1 / 255))

-        transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
-
-    if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         stats = offline_buffer.compute_or_load_stats()
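
Once built, the `transforms` list is wrapped into a single transform and handed to the replay buffer, which applies it on every `sample()` call (test.py below uses the same `Compose(*transforms)` pattern). For debugging, the composed transform can also be applied manually; a minimal sketch, with `offline_buffer` standing in for any TensorDictReplayBuffer constructed without a transform:

    from torchrl.envs.transforms.transforms import Compose

    transform = Compose(*transforms)  # sub-transforms apply left to right
    batch = transform(offline_buffer.sample())
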
@@ -3,6 +3,7 @@ import zipfile
 from pathlib import Path

 import requests
+import torch
 import tqdm
@@ -28,3 +29,26 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False
+
+
+def yuv_to_rgb(frames):
+    assert frames.dtype == torch.uint8
+    assert frames.ndim == 4
+    assert frames.shape[1] == 3
+
+    frames = frames.cpu().to(torch.float)
+    y = frames[..., 0, :, :]
+    u = frames[..., 1, :, :]
+    v = frames[..., 2, :, :]
+
+    y /= 255
+    u = u / 255 - 0.5
+    v = v / 255 - 0.5
+
+    r = y + 1.14 * v
+    g = y + -0.396 * u - 0.581 * v
+    b = y + 2.029 * u
+
+    rgb = torch.stack([r, g, b], 1)
+    rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
+    return rgb
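
The helper converts full-range planar YUV to RGB with approximate analog coefficients (the copy in test.py below is updated to the more precise BT.601 constants). A quick shape/dtype sanity check on dummy data:

    import torch

    yuv = torch.randint(0, 256, (2, 3, 4, 4), dtype=torch.uint8)  # batch of 2 YUV frames, channels-first
    rgb = yuv_to_rgb(yuv)
    assert rgb.shape == (2, 3, 4, 4) and rgb.dtype == torch.uint8
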
@@ -1,5 +1,8 @@
 import logging
+import shutil
+import subprocess
 import threading
+import time
 from pathlib import Path

 import einops
@@ -26,9 +29,48 @@ def visualize_dataset_cli(cfg: dict):

 def cat_and_write_video(video_path, frames, fps):
     frames = torch.cat(frames)
-    assert frames.dtype == torch.uint8
+    if frames.dtype != torch.uint8:
+        logging.warning(f"frames are expected to be uint8, got {frames.dtype}")
+        frames = frames.type(torch.uint8)
+
+    _, _, h, w = frames.shape
     frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
-    imageio.mimsave(video_path, frames, fps=fps)
+
+    img_dir = Path(video_path.split(".")[0])
+    if img_dir.exists():
+        shutil.rmtree(img_dir)
+    img_dir.mkdir(parents=True, exist_ok=True)
+
+    for i in range(len(frames)):
+        imageio.imwrite(str(img_dir / f"frame_{i:04d}.png"), frames[i])
+
+    ffmpeg_command = [
+        "ffmpeg",
+        "-r",
+        str(fps),
+        "-f",
+        "image2",
+        "-s",
+        f"{w}x{h}",
+        "-i",
+        str(img_dir / "frame_%04d.png"),
+        "-vcodec",
+        "libx264",
+        # '-vcodec', 'libx265',
+        # '-vcodec', 'libaom-av1',
+        "-crf",
+        "0",  # lossless option
+        "-pix_fmt",
+        "yuv420p",  # specify pixel format
+        video_path,
+        # video_path.replace(".mp4", ".mkv")
+    ]
+    subprocess.run(ffmpeg_command, check=True)
+
+    time.sleep(0.1)
+
+    # clean temporary image directory
+    # shutil.rmtree(img_dir)


 def visualize_dataset(cfg: dict, out_dir=None):
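
With `-crf 0`, libx264 runs in lossless mode, but `-pix_fmt yuv420p` still subsamples chroma, so the round trip is only lossless with respect to the yuv420p representation, not the original RGB frames. A quick check that the encoded file holds the expected number of frames, using ffprobe through subprocess (a sketch mirroring `count_frames` in test.py):

    import subprocess

    def probe_num_frames(video_path: str) -> int:
        # ask ffprobe to decode and count the frames of the first video stream
        out = subprocess.check_output([
            "ffprobe", "-v", "error", "-select_streams", "v:0", "-count_frames",
            "-show_entries", "stream=nb_read_frames",
            "-of", "default=nokey=1:noprint_wrappers=1", video_path,
        ])
        return int(out.decode().strip())
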
@@ -61,7 +103,10 @@ def visualize_dataset(cfg: dict, out_dir=None):

         # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
         no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
-        new_episode = ep_idx != current_ep_idx
+        new_episode = ep_idx > current_ep_idx
+
+        if ep_idx < current_ep_idx:
+            break

         if new_episode:
             logging.info(f"Visualizing episode {current_ep_idx}")
@@ -71,7 +116,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
         # append last observed frames (the ones after last action taken)
         frames[im_key].append(ep_td[("next", *im_key)])

-    video_dir = Path(out_dir) / "visualize_dataset"
+    video_dir = Path(out_dir) / "videos"
     video_dir.mkdir(parents=True, exist_ok=True)

     if len(offline_buffer.image_keys) > 1:
test.py (296 lines changed)
@@ -4,24 +4,21 @@
 import subprocess
 from collections.abc import Callable
 from pathlib import Path
 from typing import Sequence

 import einops
 import torch
 import torchaudio
 import torchrl
 from matplotlib import pyplot as plt
-from tensordict import TensorDict, TensorDictBase
-from tensordict.nn import dispatch
-from tensordict.utils import NestedKey
-from torchaudio.io import StreamReader
+from tensordict import TensorDict
 from torchaudio.utils import ffmpeg_utils
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-from torchrl.envs.transforms import Transform
 from torchrl.envs.transforms.transforms import Compose

+from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform
 from lerobot.common.utils import set_seed

 NUM_STATE_CHANNELS = 12
@@ -42,15 +39,32 @@ def yuv_to_rgb(frames):
     u = u / 255 - 0.5
     v = v / 255 - 0.5

-    r = y + 1.14 * v
-    g = y + -0.396 * u - 0.581 * v
-    b = y + 2.029 * u
+    r = y + 1.13983 * v
+    g = y + -0.39465 * u - 0.58060 * v
+    b = y + 2.03211 * u

     rgb = torch.stack([r, g, b], 1)
     rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
     return rgb


+def yuv_to_rgb_cv2(frames, return_hwc=True):
+    assert frames.dtype == torch.uint8
+    assert frames.ndim == 4
+    assert frames.shape[1] == 3
+    frames = frames.cpu()
+    import cv2
+
+    frames = einops.rearrange(frames, "b c h w -> b h w c")
+    frames = frames.numpy()
+    frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
+    frames = [torch.from_numpy(frame) for frame in frames]
+    frames = torch.stack(frames)
+    if not return_hwc:
+        frames = einops.rearrange(frames, "b h w c -> b c h w")
+    return frames


 def count_frames(video_path):
     try:
         # Construct the ffprobe command to get the number of frames
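
The updated constants are the standard BT.601 YUV-to-RGB conversion coefficients (1.13983, 0.39465, 0.58060, 2.03211), which `cv2.COLOR_YUV2RGB` should match closely. A quick consistency check against the cv2 reference added above (a sketch; requires opencv-python):

    import torch

    yuv = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)
    ours = yuv_to_rgb(yuv).float()
    ref = yuv_to_rgb_cv2(yuv, return_hwc=False).float()
    print("max abs difference vs cv2:", (ours - ref).abs().max().item())  # expected to be small
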
@@ -131,211 +145,6 @@ get_frame_timestamps(frame_rate, num_frames):
 # return td


-class ViewSliceHorizonTransform(Transform):
-    invertible = False
-
-    def __init__(self, num_slices, horizon):
-        super().__init__()
-        self.num_slices = num_slices
-        self.horizon = horizon
-
-    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
-        # _reset is called once when the environment reset to normalize the first observation
-        tensordict_reset = self._call(tensordict_reset)
-        return tensordict_reset
-
-    @dispatch(source="in_keys", dest="out_keys")
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        return self._call(tensordict)
-
-    def _call(self, td: TensorDictBase) -> TensorDictBase:
-        td = td.view(self.num_slices, self.horizon)
-        return td
-
-
-class KeepFrames(Transform):
-    invertible = False
-
-    def __init__(
-        self,
-        positions,
-        in_keys: Sequence[NestedKey],
-        out_keys: Sequence[NestedKey] = None,
-    ):
-        if isinstance(positions, list):
-            assert isinstance(positions[0], int)
-        # TODO(rcadene): add support for `isinstance(positions, int)`?
-
-        self.positions = positions
-        if out_keys is None:
-            out_keys = in_keys
-        super().__init__(in_keys=in_keys, out_keys=out_keys)
-
-    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
-        # _reset is called once when the environment reset to normalize the first observation
-        tensordict_reset = self._call(tensordict_reset)
-        return tensordict_reset
-
-    @dispatch(source="in_keys", dest="out_keys")
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        return self._call(tensordict)
-
-    def _call(self, td: TensorDictBase) -> TensorDictBase:
-        # we need to set batch_size=[] before assigning a different shape to td[outkey]
-        td.batch_size = []
-
-        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
-            # TODO(rcadene): don't know how to do `inkey not in td`
-            if td.get(inkey, None) is None:
-                continue
-            td[outkey] = td[inkey][:, self.positions]
-        return td
-
-
-class DecodeVideoTransform(Transform):
-    invertible = False
-
-    def __init__(
-        self,
-        device="cpu",
-        # format options are None=yuv420p (usually), rgb24, bgr24, etc.
-        format: str | None = None,
-        frame_rate: int | None = None,
-        width: int | None = None,
-        height: int | None = None,
-        in_keys: Sequence[NestedKey] = None,
-        out_keys: Sequence[NestedKey] = None,
-        in_keys_inv: Sequence[NestedKey] | None = None,
-        out_keys_inv: Sequence[NestedKey] | None = None,
-    ):
-        self.device = device
-        self.format = format
-        self.frame_rate = frame_rate
-        self.width = width
-        self.height = height
-        self.video_id_to_path = None
-        if out_keys is None:
-            out_keys = in_keys
-        if in_keys_inv is None:
-            in_keys_inv = out_keys
-        if out_keys_inv is None:
-            out_keys_inv = in_keys
-        super().__init__(
-            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
-        )
-
-    def set_video_id_to_path(self, video_id_to_path):
-        self.video_id_to_path = video_id_to_path
-
-    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
-        # _reset is called once when the environment reset to normalize the first observation
-        tensordict_reset = self._call(tensordict_reset)
-        return tensordict_reset
-
-    @dispatch(source="in_keys", dest="out_keys")
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        return self._call(tensordict)
-
-    def _call(self, td: TensorDictBase) -> TensorDictBase:
-        assert (
-            self.video_id_to_path is not None
-        ), "Setting a video_id_to_path dictionary with `self.set_video_id_to_path(video_id_to_path)` is required."
-
-        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
-            # TODO(rcadene): don't know how to do `inkey not in td`
-            if td.get(inkey, None) is None:
-                continue
-
-            bsize = len(td[inkey])  # num episodes in the batch
-            b_frames = []
-            for i in range(bsize):
-                assert (
-                    td["observation", "frame", "video_id"].ndim == 2
-                ), "We expect 2 dims. Respectively, number of episodes in the batch and number of observations"
-
-                ep_video_ids = td[inkey]["video_id"][i]
-                timestamps = td[inkey]["timestamp"][i]
-                frame_ids = td["frame_id"][i]
-
-                unique_video_id = (ep_video_ids.min() == ep_video_ids.max()).item()
-                assert unique_video_id
-
-                is_ascending = torch.all(timestamps[:-1] <= timestamps[1:]).item()
-                assert is_ascending
-
-                is_contiguous = ((frame_ids[1:] - frame_ids[:-1]) == 1).all().item()
-                assert is_contiguous
-
-                FIRST_FRAME = 0  # noqa: N806
-                video_id = ep_video_ids[FIRST_FRAME].item()
-                video_path = self.video_id_to_path[video_id]
-                first_frame_ts = timestamps[FIRST_FRAME].item()
-                num_contiguous_frames = len(timestamps)
-
-                filter_desc = []
-                video_stream_kwgs = {
-                    "frames_per_chunk": num_contiguous_frames,
-                    "buffer_chunk_size": num_contiguous_frames,
-                }
-
-                # choice of decoder
-                if self.device == "cuda":
-                    video_stream_kwgs["hw_accel"] = "cuda"
-                    video_stream_kwgs["decoder"] = "h264_cuvid"
-                else:
-                    video_stream_kwgs["decoder"] = "h264"
-
-                # resize
-                resize_width = self.width is not None
-                resize_height = self.height is not None
-                if resize_width or resize_height:
-                    if self.device == "cuda":
-                        assert resize_width and resize_height
-                        video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
-                    else:
-                        scales = []
-                        if resize_width:
-                            scales.append(f"width={self.width}")
-                        if resize_height:
-                            scales.append(f"height={self.height}")
-                        filter_desc.append(f"scale={':'.join(scales)}")
-
-                # choice of format
-                if self.format is not None:
-                    if self.device == "cuda":
-                        # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
-                        raise NotImplementedError()
-                        # filter_desc = f"scale=format={self.format}"
-                        # filter_desc = f"scale_cuda=format={self.format}"
-                        # filter_desc = f"scale_npp=format={self.format}"
-                    else:
-                        filter_desc.append(f"format=pix_fmts={self.format}")
-
-                # choice of frame rate
-                if self.frame_rate is not None:
-                    filter_desc.append(f"fps={self.frame_rate}")
-
-                if len(filter_desc) > 0:
-                    video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
-
-                # create a stream and load a certain number of frames at a certain frame rate
-                # TODO(rcadene): make sure it's the most optimal way to do it
-                s = StreamReader(video_path)
-                s.seek(first_frame_ts)
-                s.add_video_stream(**video_stream_kwgs)
-                s.fill_buffer()
-                (frames,) = s.pop_chunks()
-
-                b_frames.append(frames)
-
-            td[outkey] = torch.stack(b_frames)
-
-            if self.device == "cuda":
-                # make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu
-                assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
-        return td


 class VideoExperienceReplay(TensorDictReplayBuffer):
     def __init__(
         self,
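
The three transforms deleted here now live in `lerobot.common.datasets.transforms`, matching the new imports at the top of test.py. A minimal sketch of the relocated API (the keys and the path mapping are illustrative):

    from torchrl.envs.transforms.transforms import Compose

    from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform

    transform = Compose(
        ViewSliceHorizonTransform(num_slices=1, horizon=2),
        KeepFrames(positions=[0], in_keys=[("observation")]),
        DecodeVideoTransform(device="cpu", in_keys=[("observation", "frame")], out_keys=[("observation", "frame", "data")]),
    )
    for tf in transform:  # Compose is iterable over its sub-transforms
        if isinstance(tf, DecodeVideoTransform):
            tf.set_video_id_to_path({0: "videos/episode_0.mp4"})  # hypothetical mapping
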
@@ -349,7 +158,7 @@ class VideoExperienceReplay(TensorDictReplayBuffer):
         writer: Writer = None,
         transform: "torchrl.envs.Transform" = None,
     ):
-        self.data_dir = root / "2024_03_17_test_dataset"
+        self.data_dir = root
         self.rb_dir = self.data_dir / "replay_buffer"

         storage, meta_data = self._load_or_download()
@@ -454,7 +263,18 @@ if __name__ == "__main__":

     import tqdm

+    print("FFmpeg Library versions:")
+    for k, ver in ffmpeg_utils.get_versions().items():
+        print(f" {k}:\t{'.'.join(str(v) for v in ver)}")
+
+    print("Available NVDEC Decoders:")
+    for k in ffmpeg_utils.get_video_decoders().keys():  # noqa: SIM118
+        if "cuvid" in k:
+            print(f" - {k}")
+
     def create_replay_buffer(device):
+        data_dir = Path("tmp/2024_03_17_data_video/pusht")
+
         num_slices = 1
         horizon = 2
         batch_size = num_slices * horizon
@@ -470,6 +290,7 @@ if __name__ == "__main__":
             ViewSliceHorizonTransform(num_slices, horizon),
             KeepFrames(positions=[0], in_keys=[("observation")]),
             DecodeVideoTransform(
+                data_dir=data_dir,
                 device=device,
                 frame_rate=None,
                 in_keys=[("observation", "frame")],
@@ -478,7 +299,7 @@ if __name__ == "__main__":
     ]

     replay_buffer = VideoExperienceReplay(
-        root=Path("tmp"),
+        root=data_dir,
         batch_size=batch_size,
         # prefetch=4,
         transform=Compose(*transforms),
@@ -489,52 +310,57 @@ if __name__ == "__main__":
     def test_time():
         replay_buffer = create_replay_buffer(device="cuda")

-        start = time.time()
+        start = time.monotonic()
         for _ in tqdm.tqdm(range(2)):
             # include_info=False is required to not have a batch_size mismatch error with the truncated key (2,8) != (16, 1)
             replay_buffer.sample(include_info=False)
             torch.cuda.synchronize()
-        print(time.time() - start)
+        print(time.monotonic() - start)

-        start = time.time()
+        start = time.monotonic()
         for _ in tqdm.tqdm(range(10)):
             replay_buffer.sample(include_info=False)
             torch.cuda.synchronize()
-        print(time.time() - start)
+        print(time.monotonic() - start)

-    def test_plot():
+    def test_plot(seed=1337):
         rb_cuda = create_replay_buffer(device="cuda")
-        rb_cpu = create_replay_buffer(device="cpu")
+        rb_cpu = create_replay_buffer(device="cuda")

         n_rows = 2  # len(replay_buffer)
-        fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
+        fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
         for i in range(n_rows):
-            set_seed(1337 + i)
+            set_seed(seed + i)
             batch_cpu = rb_cpu.sample(include_info=False)
-            print(batch_cpu["frame_id"])
+            print("frame_ids cpu", batch_cpu["frame_id"].tolist())
+            print("episode cpu", batch_cpu["episode"].tolist())
+            print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
             frames = batch_cpu["observation", "frame", "data"]

             frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
-            frames = yuv_to_rgb(frames)
-            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
-
+            frames = yuv_to_rgb(frames, return_hwc=True)
+            assert frames.shape[0] == 1
             axes[i][0].imshow(frames[0])

-            set_seed(1337 + i)
+            set_seed(seed + i)
             batch_cuda = rb_cuda.sample(include_info=False)
-            print(batch_cuda["frame_id"])
+            print("frame_ids cuda", batch_cuda["frame_id"].tolist())
+            print("episode cuda", batch_cuda["episode"].tolist())
+            print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
             frames = batch_cuda["observation", "frame", "data"]

             frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
-            frames = yuv_to_rgb(frames)
-            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
-
+            frames = yuv_to_rgb(frames, return_hwc=True)
+            assert frames.shape[0] == 1
             axes[i][1].imshow(frames[0])

+            frames = batch_cuda["observation", "image"].type(torch.uint8)
+            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
+            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
+            assert frames.shape[0] == 1
+            axes[i][2].imshow(frames[0])
+
+        axes[0][0].set_title("Software decoder")
+        axes[0][1].set_title("HW decoder")
+        axes[0][2].set_title("uint8")
         plt.setp(axes, xticks=[], yticks=[])
         plt.tight_layout()
         fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)
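
Two details in test_time generalize: time.monotonic() is immune to system clock adjustments, making it the right clock for measuring intervals, and torch.cuda.synchronize() must run before each clock read because CUDA ops execute asynchronously. A small helper following the same pattern (a sketch):

    import time

    import torch

    def avg_seconds_per_call(fn, iters=10):
        # synchronize before reading the clock so pending CUDA work is counted
        torch.cuda.synchronize()
        start = time.monotonic()
        for _ in range(iters):
            fn()
            torch.cuda.synchronize()
        return (time.monotonic() - start) / iters
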