This commit is contained in:
Jade Choghari 2025-02-22 18:10:12 +01:00
parent 6ca03b0dac
commit cf6e677485
2 changed files with 8 additions and 8 deletions

View File

@ -39,6 +39,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import (
decode_video_frames_torchvision,
encode_video_frames,
decode_video_frames_torchcodec,
)
from lerobot.common.utils.benchmark import TimeBenchmark
@ -67,10 +68,6 @@ def parse_int_or_none(value) -> int | None:
def check_datasets_formats(repo_ids: list) -> None:
    """Verify that none of the given datasets is a video dataset.

    Each entry of ``repo_ids`` is loaded as a ``LeRobotDataset``; iteration
    stops at the first dataset whose ``video`` attribute is truthy.

    Args:
        repo_ids: Repository ids of the datasets to check.

    Raises:
        ValueError: If a loaded dataset reports ``video`` as truthy.
    """
    for repo_id in repo_ids:
        # Datasets that are not video-based are accepted; move on to the next one.
        if not LeRobotDataset(repo_id).video:
            continue
        raise ValueError(
            f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
        )
def get_directory_size(directory: Path) -> int:
@ -155,6 +152,10 @@ def decode_video_frames(
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
elif backend in ["torchcodec-cpu", "torchcodec-gpu"]:
# Only pass device once depending on the backend
device = "cpu" if backend == "torchcodec-cpu" else "cuda"
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, device=device)
else:
raise NotImplementedError(backend)
@ -188,7 +189,7 @@ def benchmark_decoding(
original_frames = load_original_frames(imgs_dir, timestamps, fps)
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
frames_np, original_frames_np = frames.cpu().numpy(), original_frames.cpu().numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["psnr_values"].append(

View File

@ -52,10 +52,9 @@ from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
from lerobot.common.datasets.video_utils import (
decode_video_frames_torchvision
decode_video_frames_torchvision, decode_video_frames_torchcodec
)
# Define a custom collate function.
from torchcodec.decoders import VideoDecoder
def custom_collate_fn(batch):
"""
@ -81,7 +80,7 @@ def custom_collate_fn(batch):
)
# stack frames for this video key and add directly to the item
item[vid_key] = torch.stack(frames)
item[vid_key] = frames
# add item data (both video and non-video) to final_batch
for key, value in item.items():