add convert stats in parallel using multiple thread and decord video backend.

This commit is contained in:
ChopinChen 2025-04-09 14:47:22 +08:00
parent 2c86fea78a
commit d8f53d93c0
3 changed files with 54 additions and 5 deletions

View File

@ -38,7 +38,7 @@ from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats, convert_stats_parallel
V20 = "v2.0"
V21 = "v2.1"
@ -57,14 +57,15 @@ def convert_dataset(
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
video_backend: str = "pyav",
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True, video_backend=video_backend)
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
convert_stats_parallel(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)

View File

@ -16,6 +16,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor, as_completed
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -30,7 +32,7 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool = False):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
@ -52,8 +54,11 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
if not is_parallel:
dataset.meta.episodes_stats[ep_idx] = ep_stats
return ep_stats, ep_idx
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
assert dataset.episodes is None
@ -75,6 +80,31 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0):
"""Convert stats in parallel using multiple thread."""
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
futures = []
max_workers = min(cpu_count(), num_workers)
if num_workers > 0:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for ep_idx in range(total_episodes):
futures.append(
executor.submit(convert_episode_stats, dataset, ep_idx, True)
)
for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"):
ep_stats, ep_data = future.result()
dataset.meta.episodes_stats[ep_idx] = ep_data
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],

View File

@ -23,6 +23,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar
import decord
import pyarrow as pa
import torch
import torchvision
@ -66,6 +67,8 @@ def decode_video_frames(
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
elif backend == "decord":
return decode_video_frames_decord(video_path, timestamps)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@ -243,6 +246,21 @@ def decode_video_frames_torchcodec(
return closest_frames
def decode_video_frames_decord(
video_path: Path | str,
timestamps: list[float],
) -> torch.Tensor::
video_path = str(video_path)
vr = decord.VideoReader(video_path)
num_frames = len(vr)
frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames))
indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0)
frames = vr.get_batch(indices)
frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255
return frames_tensor
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,