diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index 176d16d0..53cacbef 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -38,7 +38,10 @@ from huggingface_hub import HfApi from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info -from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats +from lerobot.common.datasets.v21.convert_stats import ( + check_aggregate_stats, + convert_stats_parallel, +) V20 = "v2.0" V21 = "v2.1" @@ -57,14 +60,15 @@ def convert_dataset( repo_id: str, branch: str | None = None, num_workers: int = 4, + video_backend: str = "pyav", ): with SuppressWarnings(): - dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) + dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True, video_backend=video_backend) if (dataset.root / EPISODES_STATS_PATH).is_file(): (dataset.root / EPISODES_STATS_PATH).unlink() - convert_stats(dataset, num_workers=num_workers) + convert_stats_parallel(dataset, num_workers=num_workers) ref_stats = load_stats(dataset.root) check_aggregate_stats(dataset, ref_stats) diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py index 4a20b427..cbce5f89 100644 --- a/lerobot/common/datasets/v21/convert_stats.py +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -13,6 +13,7 @@ # limitations under the License. from concurrent.futures import ThreadPoolExecutor, as_completed +from multiprocessing import cpu_count import numpy as np from tqdm import tqdm @@ -30,7 +31,7 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_ return video_frames[ft_key].numpy() -def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): +def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool = False): ep_start_idx = dataset.episode_data_index["from"][ep_idx] ep_end_idx = dataset.episode_data_index["to"][ep_idx] ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) @@ -52,7 +53,10 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() } - dataset.meta.episodes_stats[ep_idx] = ep_stats + if not is_parallel: + dataset.meta.episodes_stats[ep_idx] = ep_stats + + return ep_stats, ep_idx def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): @@ -75,6 +79,29 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) +def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0): + """Convert stats in parallel using multiple thread.""" + assert dataset.episodes is None + print("Computing episodes stats") + total_episodes = dataset.meta.total_episodes + futures = [] + + max_workers = min(cpu_count(), num_workers) + if num_workers > 0: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for ep_idx in range(total_episodes): + futures.append(executor.submit(convert_episode_stats, dataset, ep_idx, True)) + for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"): + ep_stats, ep_data = future.result() + dataset.meta.episodes_stats[ep_idx] = ep_data + else: + for ep_idx in tqdm(range(total_episodes)): + convert_episode_stats(dataset, ep_idx) + + for ep_idx in tqdm(range(total_episodes)): + write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) + + def check_aggregate_stats( dataset: LeRobotDataset, reference_stats: dict[str, dict[str, np.ndarray]], diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index c38d570d..8d361fb8 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -23,6 +23,8 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, ClassVar +import decord +import numpy as np import pyarrow as pa import torch import torchvision @@ -66,6 +68,8 @@ def decode_video_frames( return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) elif backend in ["pyav", "video_reader"]: return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + elif backend == "decord": + return decode_video_frames_decord(video_path, timestamps) else: raise ValueError(f"Unsupported video backend: {backend}") @@ -243,6 +247,21 @@ def decode_video_frames_torchcodec( return closest_frames +def decode_video_frames_decord( + video_path: Path | str, + timestamps: list[float], +) -> torch.Tensor: + video_path = str(video_path) + vr = decord.VideoReader(video_path) + num_frames = len(vr) + frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames)) + indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0) + frames = vr.get_batch(indices) + + frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255 + return frames_tensor + + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str,