331 lines
12 KiB
Python
331 lines
12 KiB
Python
# TODO(rcadene): add tests
|
|
# TODO(rcadene): what is the best format to store/load videos?
|
|
|
|
import subprocess
|
|
from collections.abc import Callable
|
|
from pathlib import Path
|
|
|
|
import einops
|
|
import torch
|
|
import torchaudio
|
|
import torchrl
|
|
from matplotlib import pyplot as plt
|
|
from tensordict import TensorDict
|
|
from torchaudio.utils import ffmpeg_utils
|
|
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
|
from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement
|
|
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
|
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
|
from torchrl.envs.transforms.transforms import Compose
|
|
|
|
from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform
|
|
from lerobot.common.utils import set_seed
|
|
|
|
# Widths of the placeholder state and action vectors (filled with torch.randn)
# that VideoExperienceReplay._download stores for every frame.
NUM_STATE_CHANNELS, NUM_ACTION_CHANNELS = 12, 12
|
|
|
|
|
|
def count_frames(video_path):
    """Return the number of frames in the first video stream of ``video_path``.

    The count is the ``nb_frames`` entry that ``ffprobe`` reports for stream ``v:0``,
    so ``ffprobe`` must be available on ``PATH``.

    Args:
        video_path: Path (``str`` or ``os.PathLike``) to the video file.

    Returns:
        The frame count as an ``int``. Returns ``-1`` in three cases: ``ffprobe`` could
        not be started, it exited with an error, or its output was not an integer.
    """
    cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=nb_frames",
        "-of",
        "default=nokey=1:noprint_wrappers=1",
        str(video_path),
    ]
    try:
        # check=True reports a failing ffprobe as CalledProcessError. Without it, the
        # failure would only show up later as a confusing int("") parse error.
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        # int() raises ValueError on non-numeric output (e.g. "N/A" for some containers).
        return int(result.stdout.strip())
    except (OSError, subprocess.CalledProcessError, ValueError) as e:
        # Best effort: report the probe failure and signal it with -1.
        print(f"An error occurred: {e}")
        return -1
|
|
|
|
|
|
def get_frame_rate(video_path):
    """Return the frame rate (frames per second) of the first video stream of ``video_path``.

    The rate is the ``r_frame_rate`` entry that ``ffprobe`` reports for stream ``v:0``.
    ffprobe prints it as a rational string such as ``"30000/1001"``.

    Args:
        video_path: Path (``str`` or ``os.PathLike``) to the video file.

    Returns:
        The frame rate as a ``float``. Returns ``-1`` in three cases: ``ffprobe`` could
        not be started, it exited with an error, or its output was not a valid,
        non-zero-denominator fraction.
    """
    from fractions import Fraction

    cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=r_frame_rate",
        "-of",
        "default=nokey=1:noprint_wrappers=1",
        str(video_path),
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        # Parse the rational with Fraction instead of eval(). The text comes from an
        # external process and must never be executed as Python code.
        return float(Fraction(result.stdout.strip()))
    except (OSError, subprocess.CalledProcessError, ValueError, ZeroDivisionError) as e:
        # Best effort: report the probe failure and signal it with -1.
        print(f"An error occurred: {e}")
        return -1
|
|
|
|
|
|
def get_frame_timestamps(frame_rate, num_frames):
    """Return the timestamp, in seconds, of each of ``num_frames`` frames.

    A constant frame rate is assumed, so frame ``i`` sits at ``i / frame_rate`` seconds.

    Args:
        frame_rate: Frames per second.
        num_frames: Number of timestamps to generate. Zero or negative values yield ``[]``.

    Returns:
        A list of ``num_frames`` timestamps, starting at ``0.0``.

    Raises:
        ZeroDivisionError: if ``frame_rate`` is zero and ``num_frames`` is positive.
    """
    # Divide per frame rather than multiplying by a pre-rounded 1 / frame_rate. Each
    # timestamp is then one correctly-rounded division: 3 / 10 == 0.3, whereas
    # (1 / 10) * 3 == 0.30000000000000004.
    return [i / frame_rate for i in range(num_frames)]
|
|
|
|
|
|
# class ClearDeviceTransform(Transform):
|
|
# invertible = False
|
|
|
|
# def __init__(self):
|
|
# super().__init__()
|
|
|
|
# def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
|
# # _reset is called once when the environment reset to normalize the first observation
|
|
# tensordict_reset = self._call(tensordict_reset)
|
|
# return tensordict_reset
|
|
|
|
# @dispatch(source="in_keys", dest="out_keys")
|
|
# def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
|
# return self._call(tensordict)
|
|
|
|
# def _call(self, td: TensorDictBase) -> TensorDictBase:
|
|
# td.clear_device_()
|
|
# return td
|
|
|
|
|
|
class VideoExperienceReplay(TensorDictReplayBuffer):
    """Replay buffer whose frames are stored as references into video files.

    Each stored step keeps ``("observation", "frame", "video_id")`` and
    ``("observation", "frame", "timestamp")`` rather than decoded pixels. Every
    ``DecodeVideoTransform`` found in ``transform`` is given the ``video_id -> path``
    mapping, which it presumably uses to decode frames at sample time.

    Data lives under ``root``:

    - ``root / "replay_buffer"``: the memory-mapped TensorDict.
    - ``root / "meta_data.pth"``: the metadata dict.

    If either is missing, ``_download`` rebuilds the buffer. It fetches one demo
    video and fills it with placeholder state/action/reward data.
    """

    def __init__(
        self,
        batch_size: int = None,
        *,
        root: Path = None,
        pin_memory: bool = False,
        prefetch: int = None,
        sampler: SliceSampler = None,
        collate_fn: Callable = None,
        writer: Writer = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        """Load (or build) the dataset under ``root`` and initialise the replay buffer.

        Args:
            batch_size: Forwarded to ``TensorDictReplayBuffer``.
            root: Directory holding (or receiving) the dataset. Accepts ``str`` or ``Path``.
            pin_memory: Forwarded to ``TensorDictReplayBuffer``.
            prefetch: Forwarded to ``TensorDictReplayBuffer``.
            sampler: Forwarded to ``TensorDictReplayBuffer``.
            collate_fn: Defaults to ``_collate_id`` when ``None``.
            writer: Defaults to ``ImmutableDatasetWriter()`` when ``None``.
            transform: A ``Compose``, a single transform, or ``None``. Any
                ``DecodeVideoTransform`` inside it receives the video paths.

        Raises:
            TypeError: if ``root`` is ``None``.
        """
        if root is None:
            raise TypeError("VideoExperienceReplay requires `root` to locate or store the dataset.")
        self.data_dir = Path(root)
        self.rb_dir = self.data_dir / "replay_buffer"

        storage, meta_data = self._load_or_download()

        # Hack to access video paths: they are only known once the metadata is loaded,
        # so inject them into every DecodeVideoTransform we can find.
        self._set_video_paths(transform, meta_data["video_id_to_path"])

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @staticmethod
    def _set_video_paths(transform, video_id_to_path):
        """Give every ``DecodeVideoTransform`` in ``transform`` the ``video_id -> path`` mapping."""
        if transform is None:
            return
        transforms = transform if isinstance(transform, Compose) else [transform]
        for tf in transforms:
            if isinstance(tf, DecodeVideoTransform):
                tf.set_video_id_to_path(video_id_to_path)

    def _load_or_download(self, force_download=False):
        """Return ``(storage, meta_data)``, reusing data already on disk when possible.

        Cached data is reused only when both the memory-mapped buffer directory and the
        metadata file exist. Checking ``data_dir`` alone would make an interrupted
        earlier build fail on every run. Otherwise, or when ``force_download`` is set,
        the buffer is rebuilt with ``_download`` and its metadata saved to
        ``data_dir / "meta_data.pth"``.
        """
        meta_data_path = self.data_dir / "meta_data.pth"
        if not force_download and self.rb_dir.exists() and meta_data_path.exists():
            storage = TensorStorage(TensorDict.load_memmap(self.rb_dir))
            # NOTE: torch.load unpickles the file; only point `root` at trusted data.
            meta_data = torch.load(meta_data_path)
        else:
            # torch.save below writes into data_dir, so make sure it exists.
            self.data_dir.mkdir(parents=True, exist_ok=True)
            storage, meta_data = self._download()
            torch.save(meta_data, meta_data_path)

        # required to not send cuda frames to cpu by default
        storage._storage.clear_device_()
        return storage, meta_data

    def _download(self):
        """Build the memory-mapped buffer from the demo video.

        Each stored step pairs frame ``t`` (observation) with frame ``t + 1`` (next).
        State and action are random placeholders. Reward and success are zero. Only the
        last transition is marked done.

        Returns:
            ``(TensorStorage, meta_data)``, where ``meta_data["video_id_to_path"]`` maps
            each video id to the downloaded file path.

        Raises:
            RuntimeError: in either of two cases:
                - ffprobe did not report at least 2 frames (needed for one transition);
                - ffprobe did not report a positive frame rate.
        """
        num_episodes = 1
        video_id_to_path = {}
        for episode_id in range(num_episodes):
            video_path = torchaudio.utils.download_asset(
                "tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4"
            )
            # several episodes can belong to the same video
            video_id = episode_id
            video_id_to_path[video_id] = video_path

            print(f"{video_path=}")
            num_frames = count_frames(video_path)
            print(f"{num_frames=}")
            frame_rate = get_frame_rate(video_path)
            print(f"{frame_rate=}")

            # count_frames / get_frame_rate return -1 on failure. Fail loudly here, rather
            # than via torch.zeros(-1, ...) or negative timestamps further down.
            if num_frames < 2:
                raise RuntimeError(
                    f"Could not read at least 2 frames from {video_path} (got {num_frames}); is ffprobe installed?"
                )
            if frame_rate <= 0:
                raise RuntimeError(
                    f"Could not read a positive frame rate from {video_path} (got {frame_rate}); is ffprobe installed?"
                )

            frame_timestamps = get_frame_timestamps(frame_rate, num_frames)

            reward = torch.zeros(num_frames, 1, dtype=torch.float32)
            success = torch.zeros(num_frames, 1, dtype=torch.bool)
            done = torch.zeros(num_frames, 1, dtype=torch.bool)
            state = torch.randn(num_frames, NUM_STATE_CHANNELS, dtype=torch.float32)
            action = torch.randn(num_frames, NUM_ACTION_CHANNELS, dtype=torch.float32)
            timestamp = torch.tensor(frame_timestamps)
            frame_id = torch.arange(0, num_frames, 1)
            episode_id_tensor = torch.full((num_frames,), episode_id, dtype=torch.int)
            video_id_tensor = torch.full((num_frames,), video_id, dtype=torch.int)

            # last step of demonstration is considered done
            done[-1] = True

            ep_td = TensorDict(
                {
                    ("observation", "frame", "video_id"): video_id_tensor[:-1],
                    ("observation", "frame", "timestamp"): timestamp[:-1],
                    ("observation", "state"): state[:-1],
                    "action": action[:-1],
                    "episode": episode_id_tensor[:-1],
                    "frame_id": frame_id[:-1],
                    ("next", "observation", "frame", "video_id"): video_id_tensor[1:],
                    ("next", "observation", "frame", "timestamp"): timestamp[1:],
                    ("next", "observation", "state"): state[1:],
                    ("next", "reward"): reward[1:],
                    ("next", "done"): done[1:],
                    ("next", "success"): success[1:],
                },
                batch_size=num_frames - 1,
            )

            # Pairing frame t with frame t + 1 turns `num_frames` frames into
            # `num_frames - 1` stored transitions.
            total_frames = num_frames - 1

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = ep_td[0].expand(total_frames).memmap_like(self.rb_dir)

            # NOTE(review): this is only correct for a single episode. The memmap is
            # sized for the first episode, and every episode is written over the same
            # slots. Supporting num_episodes > 1 needs a precomputed total size and a
            # running write offset.
            td_data[:] = ep_td

        meta_data = {
            "video_id_to_path": video_id_to_path,
        }

        return TensorStorage(td_data.lock_()), meta_data
|
|
|
|
|
|
if __name__ == "__main__":
    import time

    import tqdm

    print("FFmpeg Library versions:")
    for k, ver in ffmpeg_utils.get_versions().items():
        print(f" {k}:\t{'.'.join(str(v) for v in ver)}")

    print("Available NVDEC Decoders:")
    for k in ffmpeg_utils.get_video_decoders().keys():  # noqa: SIM118
        if "cuvid" in k:
            print(f" - {k}")

    def create_replay_buffer(device, format=None):
        """Build a VideoExperienceReplay that samples one 2-step slice and decodes its first frame.

        Args:
            device: Device passed to DecodeVideoTransform (e.g. "cpu" or "cuda").
            format: Pixel format passed to DecodeVideoTransform (e.g. "yuv444p"), or None.
        """
        data_dir = Path("tmp/2024_03_17_data_video/pusht")

        num_slices = 1
        horizon = 2
        batch_size = num_slices * horizon

        sampler = SliceSamplerWithoutReplacement(
            num_slices=num_slices,
            strict_length=True,
            shuffle=False,
        )

        transforms = [
            # ClearDeviceTransform(),
            ViewSliceHorizonTransform(num_slices, horizon),
            KeepFrames(positions=[0], in_keys=[("observation")]),
            DecodeVideoTransform(
                data_dir=data_dir,
                device=device,
                frame_rate=None,
                format=format,
                in_keys=[("observation", "frame")],
                out_keys=[("observation", "frame", "data")],
            ),
        ]

        replay_buffer = VideoExperienceReplay(
            root=data_dir,
            batch_size=batch_size,
            # prefetch=4,
            transform=Compose(*transforms),
            sampler=sampler,
        )
        return replay_buffer

    def test_time():
        """Time a short warm-up run and then 10 samples with CUDA decoding."""
        replay_buffer = create_replay_buffer(device="cuda")

        start = time.monotonic()
        for _ in tqdm.tqdm(range(2)):
            # include_info=False is required to not have a batch_size mismatch error with the truncated key (2,8) != (16, 1)
            replay_buffer.sample(include_info=False)
            torch.cuda.synchronize()
        print(time.monotonic() - start)

        start = time.monotonic()
        for _ in tqdm.tqdm(range(10)):
            replay_buffer.sample(include_info=False)
            torch.cuda.synchronize()
        print(time.monotonic() - start)

    def _to_hwc_image(frames):
        """Turn a (b, t, c, h, w) frame batch into the single (h, w, c) CPU image matplotlib can draw."""
        frames = einops.rearrange(frames, "b t c h w -> (b t) h w c")
        # This test setup is expected to yield exactly one frame per sample.
        assert frames.shape[0] == 1
        # matplotlib converts images through numpy, which cannot read CUDA tensors.
        return frames[0].cpu()

    def test_plot(seed=1337):
        """Plot the same sampled frame decoded by the software (CPU) and hardware (CUDA) decoders."""
        rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
        rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")

        n_rows = 2  # len(replay_buffer)
        fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
        for i in range(n_rows):
            set_seed(seed + i)
            batch_cpu = rb_cpu.sample(include_info=False)
            print("frame_ids cpu", batch_cpu["frame_id"].tolist())
            print("episode cpu", batch_cpu["episode"].tolist())
            print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
            axes[i][0].imshow(_to_hwc_image(batch_cpu["observation", "frame", "data"]))

            set_seed(seed + i)
            batch_cuda = rb_cuda.sample(include_info=False)
            print("frame_ids cuda", batch_cuda["frame_id"].tolist())
            print("episode cuda", batch_cuda["episode"].tolist())
            print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
            axes[i][1].imshow(_to_hwc_image(batch_cuda["observation", "frame", "data"]))

            # VideoExperienceReplay._download stores video references, not raw images, so
            # ("observation", "image") is normally absent. Only draw it when present.
            raw_images = batch_cuda.get(("observation", "image"), None)
            if raw_images is None:
                axes[i][2].set_axis_off()
            else:
                axes[i][2].imshow(_to_hwc_image(raw_images.type(torch.uint8)))

        axes[0][0].set_title("Software decoder")
        axes[0][1].set_title("HW decoder")
        axes[0][2].set_title("uint8")
        plt.setp(axes, xticks=[], yticks=[])
        plt.tight_layout()
        fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)

    # test_time()
    test_plot()
|