fix bug
This commit is contained in:
parent
6ca03b0dac
commit
cf6e677485
|
@ -39,6 +39,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|||
from lerobot.common.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
from lerobot.common.utils.benchmark import TimeBenchmark
|
||||
|
||||
|
@ -67,10 +68,6 @@ def parse_int_or_none(value) -> int | None:
|
|||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if dataset.video:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory: Path) -> int:
|
||||
|
@ -155,6 +152,10 @@ def decode_video_frames(
|
|||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
elif backend in ["torchcodec-cpu", "torchcodec-gpu"]:
|
||||
# Only pass device once depending on the backend
|
||||
device = "cpu" if backend == "torchcodec-cpu" else "cuda"
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, device=device)
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
|
||||
|
@ -188,7 +189,7 @@ def benchmark_decoding(
|
|||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
frames_np, original_frames_np = frames.cpu().numpy(), original_frames.cpu().numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
|
|
|
@ -52,10 +52,9 @@ from lerobot.configs import parser
|
|||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
decode_video_frames_torchvision
|
||||
decode_video_frames_torchvision, decode_video_frames_torchcodec
|
||||
)
|
||||
# let's define a custom fn
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
def custom_collate_fn(batch):
|
||||
"""
|
||||
|
@ -81,7 +80,7 @@ def custom_collate_fn(batch):
|
|||
)
|
||||
|
||||
# stack frames for this video key and add directly to the item
|
||||
item[vid_key] = torch.stack(frames)
|
||||
item[vid_key] = frames
|
||||
|
||||
# add item data (both video and non-video) to final_batch
|
||||
for key, value in item.items():
|
||||
|
|
Loading…
Reference in New Issue