# TODO(rcadene): add tests
# TODO(rcadene): what is the best format to store/load videos?

import subprocess
from collections.abc import Callable, Sequence
from fractions import Fraction
from pathlib import Path

import einops
import torch
import torchaudio
import torchrl
from matplotlib import pyplot as plt
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
from torchaudio.io import StreamReader
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms import Transform
from torchrl.envs.transforms.transforms import Compose

from lerobot.common.utils import set_seed

NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12


def yuv_to_rgb(frames):
    """Convert a batch of uint8 YUV frames of shape (B, 3, H, W) to uint8 RGB frames."""
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3

    frames = frames.cpu().to(torch.float)
    y = frames[..., 0, :, :]
    u = frames[..., 1, :, :]
    v = frames[..., 2, :, :]

    y /= 255
    u = u / 255 - 0.5
    v = v / 255 - 0.5

    r = y + 1.14 * v
    g = y + -0.396 * u - 0.581 * v
    b = y + 2.029 * u

    rgb = torch.stack([r, g, b], 1)
    rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
    return rgb
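
# Example (illustrative; the tensor below is hypothetical): converting a batch
# of decoded YUV frames to RGB for plotting:
#   frames = torch.randint(0, 256, (8, 3, 240, 320), dtype=torch.uint8)
#   rgb = yuv_to_rgb(frames)  # uint8 RGB tensor of shape (8, 3, 240, 320)
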


def count_frames(video_path):
    try:
        # Construct the ffprobe command to get the number of frames
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=nb_frames",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]

        # Execute the ffprobe command and capture the output
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # Convert the output to an integer
        num_frames = int(result.stdout.strip())

        return num_frames
    except Exception as e:
        print(f"An error occurred: {e}")
        return -1
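
# Note: stream=nb_frames relies on container metadata; for files that do not
# store it, ffprobe prints "N/A", int() raises, and count_frames returns -1.
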


def get_frame_rate(video_path):
    try:
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=r_frame_rate",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]

        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # The frame rate is typically represented as a fraction (e.g., "30000/1001").
        # Parse it with Fraction instead of eval, which would execute arbitrary output.
        frame_rate = float(Fraction(result.stdout.strip()))

        return frame_rate
    except Exception as e:
        print(f"An error occurred: {e}")
        return -1


def get_frame_timestamps(frame_rate, num_frames):
    timestamps = [(1 / frame_rate) * i for i in range(num_frames)]
    return timestamps
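
# Example (illustrative): get_frame_timestamps(frame_rate=10, num_frames=3)
# returns [0.0, 0.1, 0.2], i.e. the presentation time of each frame in seconds.
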


# class ClearDeviceTransform(Transform):
#     invertible = False

#     def __init__(self):
#         super().__init__()

#     def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
#         # _reset is called once when the environment is reset, to normalize the first observation
#         tensordict_reset = self._call(tensordict_reset)
#         return tensordict_reset

#     @dispatch(source="in_keys", dest="out_keys")
#     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
#         return self._call(tensordict)

#     def _call(self, td: TensorDictBase) -> TensorDictBase:
#         td.clear_device_()
#         return td


class ViewSliceHorizonTransform(Transform):
    invertible = False

    def __init__(self, num_slices, horizon):
        super().__init__()
        self.num_slices = num_slices
        self.horizon = horizon

    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        # _reset is called once when the environment is reset, to normalize the first observation
        tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def _call(self, td: TensorDictBase) -> TensorDictBase:
        # reshape the flat batch of sampled transitions into (num_slices, horizon)
        td = td.view(self.num_slices, self.horizon)
        return td
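
# Example (illustrative): with num_slices=2 and horizon=4, a flat batch of
# 8 sampled transitions is viewed as a (2, 4) tensordict: 2 slices of 4 steps each.
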


class KeepFrames(Transform):
    invertible = False

    def __init__(
        self,
        positions,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey] = None,
    ):
        if isinstance(positions, list):
            assert isinstance(positions[0], int)
        # TODO(rcadene): add support for `isinstance(positions, int)`?

        self.positions = positions
        if out_keys is None:
            out_keys = in_keys
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        # _reset is called once when the environment is reset, to normalize the first observation
        tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def _call(self, td: TensorDictBase) -> TensorDictBase:
        # we need to set batch_size=[] before assigning a different shape to td[outkey]
        td.batch_size = []

        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
            # TODO(rcadene): don't know how to do `inkey not in td`
            if td.get(inkey, None) is None:
                continue
            td[outkey] = td[inkey][:, self.positions]
        return td
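
# Example (illustrative): KeepFrames(positions=[0], in_keys=["observation"])
# keeps only the first time step of each slice, e.g. an entry of shape
# (num_slices, horizon, ...) becomes (num_slices, 1, ...).
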


class DecodeVideoTransform(Transform):
    invertible = False

    def __init__(
        self,
        device="cpu",
        # format options are None=yuv420p (usually), rgb24, bgr24, etc.
        format: str | None = None,
        frame_rate: int | None = None,
        width: int | None = None,
        height: int | None = None,
        in_keys: Sequence[NestedKey] = None,
        out_keys: Sequence[NestedKey] = None,
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
    ):
        self.device = device
        self.format = format
        self.frame_rate = frame_rate
        self.width = width
        self.height = height
        self.video_id_to_path = None
        if out_keys is None:
            out_keys = in_keys
        if in_keys_inv is None:
            in_keys_inv = out_keys
        if out_keys_inv is None:
            out_keys_inv = in_keys
        super().__init__(
            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
        )

    def set_video_id_to_path(self, video_id_to_path):
        self.video_id_to_path = video_id_to_path

    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        # _reset is called once when the environment is reset, to normalize the first observation
        tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def _call(self, td: TensorDictBase) -> TensorDictBase:
        assert (
            self.video_id_to_path is not None
        ), "Setting a video_id_to_path dictionary with `self.set_video_id_to_path(video_id_to_path)` is required."

        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
            # TODO(rcadene): don't know how to do `inkey not in td`
            if td.get(inkey, None) is None:
                continue

            bsize = len(td[inkey])  # num episodes in the batch
            b_frames = []
            for i in range(bsize):
                assert (
                    td["observation", "frame", "video_id"].ndim == 2
                ), "We expect 2 dims. Respectively, number of episodes in the batch and number of observations"

                ep_video_ids = td[inkey]["video_id"][i]
                timestamps = td[inkey]["timestamp"][i]
                frame_ids = td["frame_id"][i]

                # all frames of a slice must come from the same video file
                unique_video_id = (ep_video_ids.min() == ep_video_ids.max()).item()
                assert unique_video_id

                is_ascending = torch.all(timestamps[:-1] <= timestamps[1:]).item()
                assert is_ascending

                is_contiguous = ((frame_ids[1:] - frame_ids[:-1]) == 1).all().item()
                assert is_contiguous

                FIRST_FRAME = 0  # noqa: N806
                video_id = ep_video_ids[FIRST_FRAME].item()
                video_path = self.video_id_to_path[video_id]
                first_frame_ts = timestamps[FIRST_FRAME].item()
                num_contiguous_frames = len(timestamps)

                filter_desc = []
                video_stream_kwgs = {
                    "frames_per_chunk": num_contiguous_frames,
                    "buffer_chunk_size": num_contiguous_frames,
                }

                # choice of decoder
                if self.device == "cuda":
                    video_stream_kwgs["hw_accel"] = "cuda"
                    video_stream_kwgs["decoder"] = "h264_cuvid"
                else:
                    video_stream_kwgs["decoder"] = "h264"

                # resize
                resize_width = self.width is not None
                resize_height = self.height is not None
                if resize_width or resize_height:
                    if self.device == "cuda":
                        assert resize_width and resize_height
                        video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
                    else:
                        scales = []
                        if resize_width:
                            scales.append(f"width={self.width}")
                        if resize_height:
                            scales.append(f"height={self.height}")
                        filter_desc.append(f"scale={':'.join(scales)}")

                # choice of format
                if self.format is not None:
                    if self.device == "cuda":
                        # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
                        raise NotImplementedError()
                        # filter_desc = f"scale=format={self.format}"
                        # filter_desc = f"scale_cuda=format={self.format}"
                        # filter_desc = f"scale_npp=format={self.format}"
                    else:
                        filter_desc.append(f"format=pix_fmts={self.format}")

                # choice of frame rate
                if self.frame_rate is not None:
                    filter_desc.append(f"fps={self.frame_rate}")

                if len(filter_desc) > 0:
                    video_stream_kwgs["filter_desc"] = ",".join(filter_desc)

                # create a stream and load a certain number of frames at a certain frame rate
                # TODO(rcadene): make sure it's the most optimal way to do it
                s = StreamReader(video_path)
                s.seek(first_frame_ts)
                s.add_video_stream(**video_stream_kwgs)
                s.fill_buffer()
                (frames,) = s.pop_chunks()

                b_frames.append(frames)

            td[outkey] = torch.stack(b_frames)

            if self.device == "cuda":
                # make sure we return a cuda tensor, since the frames can be unintentionally moved to cpu
                assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
        return td
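
# Minimal standalone decoding sketch (illustrative; the path, timestamp, and
# chunk size are hypothetical), mirroring what DecodeVideoTransform does for
# one slice of contiguous frames:
#   s = StreamReader("video.mp4")
#   s.seek(0.0)  # seek to the timestamp of the first frame of the slice
#   s.add_video_stream(frames_per_chunk=8, buffer_chunk_size=8, decoder="h264")
#   s.fill_buffer()
#   (frames,) = s.pop_chunks()  # one uint8 tensor of shape (8, C, H, W)
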


class VideoExperienceReplay(TensorDictReplayBuffer):
    def __init__(
        self,
        batch_size: int | None = None,
        *,
        root: Path | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        sampler: SliceSampler | None = None,
        collate_fn: Callable | None = None,
        writer: Writer | None = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        self.data_dir = root / "2024_03_17_test_dataset"
        self.rb_dir = self.data_dir / "replay_buffer"

        storage, meta_data = self._load_or_download()

        # hack to access video paths
        assert isinstance(transform, Compose)
        for tf in transform:
            if isinstance(tf, DecodeVideoTransform):
                tf.set_video_id_to_path(meta_data["video_id_to_path"])

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    def _load_or_download(self, force_download=False):
        if not force_download and self.data_dir.exists():
            storage = TensorStorage(TensorDict.load_memmap(self.rb_dir))
            meta_data = torch.load(self.data_dir / "meta_data.pth")
        else:
            storage, meta_data = self._download()
            torch.save(meta_data, self.data_dir / "meta_data.pth")

        # required to not send cuda frames to cpu by default
        storage._storage.clear_device_()
        return storage, meta_data

    def _download(self):
        num_episodes = 1
        video_id_to_path = {}
        for episode_id in range(num_episodes):
            video_path = torchaudio.utils.download_asset(
                "tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4"
            )
            # several episodes can belong to the same video
            video_id = episode_id
            video_id_to_path[video_id] = video_path

            print(f"{video_path=}")
            num_frames = count_frames(video_path)
            print(f"{num_frames=}")
            frame_rate = get_frame_rate(video_path)
            print(f"{frame_rate=}")

            frame_timestamps = get_frame_timestamps(frame_rate, num_frames)

            reward = torch.zeros(num_frames, 1, dtype=torch.float32)
            success = torch.zeros(num_frames, 1, dtype=torch.bool)
            done = torch.zeros(num_frames, 1, dtype=torch.bool)
            state = torch.randn(num_frames, NUM_STATE_CHANNELS, dtype=torch.float32)
            action = torch.randn(num_frames, NUM_ACTION_CHANNELS, dtype=torch.float32)
            timestamp = torch.tensor(frame_timestamps)
            frame_id = torch.arange(0, num_frames, 1)
            episode_id_tensor = torch.tensor([episode_id] * num_frames, dtype=torch.int)
            video_id_tensor = torch.tensor([video_id] * num_frames, dtype=torch.int)

            # last step of demonstration is considered done
            done[-1] = True

            ep_td = TensorDict(
                {
                    ("observation", "frame", "video_id"): video_id_tensor[:-1],
                    ("observation", "frame", "timestamp"): timestamp[:-1],
                    ("observation", "state"): state[:-1],
                    "action": action[:-1],
                    "episode": episode_id_tensor[:-1],
                    "frame_id": frame_id[:-1],
                    ("next", "observation", "frame", "video_id"): video_id_tensor[1:],
                    ("next", "observation", "frame", "timestamp"): timestamp[1:],
                    ("next", "observation", "state"): state[1:],
                    ("next", "reward"): reward[1:],
                    ("next", "done"): done[1:],
                    ("next", "success"): success[1:],
                },
                batch_size=num_frames - 1,
            )

            # TODO:
            total_frames = num_frames - 1

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = ep_td[0].expand(total_frames).memmap_like(self.rb_dir)

            td_data[:] = ep_td

        meta_data = {
            "video_id_to_path": video_id_to_path,
        }

        return TensorStorage(td_data.lock_()), meta_data
if __name__ == "__main__":
|
|
import time
|
|
|
|
import tqdm
|
|
|
|
def create_replay_buffer(device):
|
|
num_slices = 1
|
|
horizon = 2
|
|
batch_size = num_slices * horizon
|
|
|
|
sampler = SliceSamplerWithoutReplacement(
|
|
num_slices=num_slices,
|
|
strict_length=True,
|
|
shuffle=False,
|
|
)
|
|
|
|
transforms = [
|
|
# ClearDeviceTransform(),
|
|
ViewSliceHorizonTransform(num_slices, horizon),
|
|
KeepFrames(positions=[0], in_keys=[("observation")]),
|
|
DecodeVideoTransform(
|
|
device=device,
|
|
frame_rate=None,
|
|
in_keys=[("observation", "frame")],
|
|
out_keys=[("observation", "frame", "data")],
|
|
),
|
|
]
|
|
|
|
replay_buffer = VideoExperienceReplay(
|
|
root=Path("tmp"),
|
|
batch_size=batch_size,
|
|
# prefetch=4,
|
|
transform=Compose(*transforms),
|
|
sampler=sampler,
|
|
)
|
|
return replay_buffer
|
|
|
|
def test_time():
|
|
replay_buffer = create_replay_buffer(device="cuda")
|
|
|
|
start = time.time()
|
|
for _ in tqdm.tqdm(range(2)):
|
|
# include_info=False is required to not have a batch_size mismatch error with the truncated key (2,8) != (16, 1)
|
|
replay_buffer.sample(include_info=False)
|
|
torch.cuda.synchronize()
|
|
print(time.time() - start)
|
|
|
|
start = time.time()
|
|
for _ in tqdm.tqdm(range(10)):
|
|
replay_buffer.sample(include_info=False)
|
|
torch.cuda.synchronize()
|
|
print(time.time() - start)
|
|
|
|
def test_plot():
|
|
rb_cuda = create_replay_buffer(device="cuda")
|
|
rb_cpu = create_replay_buffer(device="cpu")
|
|
|
|
n_rows = 2 # len(replay_buffer)
|
|
fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
|
|
for i in range(n_rows):
|
|
set_seed(1337 + i)
|
|
batch_cpu = rb_cpu.sample(include_info=False)
|
|
print(batch_cpu["frame_id"])
|
|
frames = batch_cpu["observation", "frame", "data"]
|
|
|
|
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
|
frames = yuv_to_rgb(frames)
|
|
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
|
|
|
assert frames.shape[0] == 1
|
|
axes[i][0].imshow(frames[0])
|
|
|
|
set_seed(1337 + i)
|
|
batch_cuda = rb_cuda.sample(include_info=False)
|
|
print(batch_cuda["frame_id"])
|
|
frames = batch_cuda["observation", "frame", "data"]
|
|
|
|
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
|
frames = yuv_to_rgb(frames)
|
|
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
|
|
|
assert frames.shape[0] == 1
|
|
axes[i][1].imshow(frames[0])
|
|
|
|
axes[0][0].set_title("Software decoder")
|
|
axes[0][1].set_title("HW decoder")
|
|
plt.setp(axes, xticks=[], yticks=[])
|
|
plt.tight_layout()
|
|
fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)
|
|
|
|
# test_time()
|
|
test_plot()
|