From d8f53d93c09022a806388164f1242b0528adcef7 Mon Sep 17 00:00:00 2001 From: ChopinChen Date: Wed, 9 Apr 2025 14:47:22 +0800 Subject: [PATCH 1/3] add convert stats in parallel using multiple thread and decord video backend. --- .../v21/convert_dataset_v20_to_v21.py | 7 ++-- lerobot/common/datasets/v21/convert_stats.py | 34 +++++++++++++++++-- lerobot/common/datasets/video_utils.py | 18 ++++++++++ 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index 176d16d0..be58188d 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -38,7 +38,7 @@ from huggingface_hub import HfApi from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info -from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats +from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats, convert_stats_parallel V20 = "v2.0" V21 = "v2.1" @@ -57,14 +57,15 @@ def convert_dataset( repo_id: str, branch: str | None = None, num_workers: int = 4, + video_backend: str = "pyav", ): with SuppressWarnings(): - dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) + dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True, video_backend=video_backend) if (dataset.root / EPISODES_STATS_PATH).is_file(): (dataset.root / EPISODES_STATS_PATH).unlink() - convert_stats(dataset, num_workers=num_workers) + convert_stats_parallel(dataset, num_workers=num_workers) ref_stats = load_stats(dataset.root) check_aggregate_stats(dataset, ref_stats) diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py index 4a20b427..a3184e13 100644 --- a/lerobot/common/datasets/v21/convert_stats.py +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -16,6 +16,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np from tqdm import tqdm +from multiprocessing import cpu_count +from concurrent.futures import ProcessPoolExecutor, as_completed from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -30,7 +32,7 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_ return video_frames[ft_key].numpy() -def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): +def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool = False): ep_start_idx = dataset.episode_data_index["from"][ep_idx] ep_end_idx = dataset.episode_data_index["to"][ep_idx] ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) @@ -52,7 +54,10 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() } - dataset.meta.episodes_stats[ep_idx] = ep_stats + if not is_parallel: + dataset.meta.episodes_stats[ep_idx] = ep_stats + + return ep_stats, ep_idx def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): @@ -75,6 +80,31 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) +def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0): + """Convert stats in parallel using multiple thread.""" + assert dataset.episodes is None + print("Computing episodes stats") + total_episodes = dataset.meta.total_episodes + futures = [] + + max_workers = min(cpu_count(), num_workers) + if num_workers > 0: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for ep_idx in range(total_episodes): + futures.append( + executor.submit(convert_episode_stats, dataset, ep_idx, True) + ) + for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"): + ep_stats, ep_data = future.result() + dataset.meta.episodes_stats[ep_idx] = ep_data + else: + for ep_idx in tqdm(range(total_episodes)): + convert_episode_stats(dataset, ep_idx) + + for ep_idx in tqdm(range(total_episodes)): + write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) + + def check_aggregate_stats( dataset: LeRobotDataset, reference_stats: dict[str, dict[str, np.ndarray]], diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index c38d570d..1e790913 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -23,6 +23,7 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, ClassVar +import decord import pyarrow as pa import torch import torchvision @@ -66,6 +67,8 @@ def decode_video_frames( return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) elif backend in ["pyav", "video_reader"]: return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + elif backend == "decord": + return decode_video_frames_decord(video_path, timestamps) else: raise ValueError(f"Unsupported video backend: {backend}") @@ -243,6 +246,21 @@ def decode_video_frames_torchcodec( return closest_frames +def decode_video_frames_decord( + video_path: Path | str, + timestamps: list[float], +) -> torch.Tensor:: + video_path = str(video_path) + vr = decord.VideoReader(video_path) + num_frames = len(vr) + frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames)) + indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0) + frames = vr.get_batch(indices) + + frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255 + return frames_tensor + + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str, From dcd0f5c5193e672ffedea32997b41ec81fef2e61 Mon Sep 17 00:00:00 2001 From: ChopinChen Date: Wed, 9 Apr 2025 15:00:35 +0800 Subject: [PATCH 2/3] add convert stats in parallel using multiple thread and decord video backend. --- lerobot/common/datasets/video_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 1e790913..129a09b0 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -24,6 +24,7 @@ from pathlib import Path from typing import Any, ClassVar import decord +import numpy as np import pyarrow as pa import torch import torchvision @@ -249,7 +250,7 @@ def decode_video_frames_torchcodec( def decode_video_frames_decord( video_path: Path | str, timestamps: list[float], -) -> torch.Tensor:: +) -> torch.Tensor: video_path = str(video_path) vr = decord.VideoReader(video_path) num_frames = len(vr) From f97bcd30e2a9f859be30cd7e0f49ec666c62fb0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 07:02:42 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../datasets/v21/convert_dataset_v20_to_v21.py | 5 ++++- lerobot/common/datasets/v21/convert_stats.py | 13 +++++-------- lerobot/common/datasets/video_utils.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index be58188d..53cacbef 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -38,7 +38,10 @@ from huggingface_hub import HfApi from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info -from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats, convert_stats_parallel +from lerobot.common.datasets.v21.convert_stats import ( + check_aggregate_stats, + convert_stats_parallel, +) V20 = "v2.0" V21 = "v2.1" diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py index a3184e13..cbce5f89 100644 --- a/lerobot/common/datasets/v21/convert_stats.py +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -13,11 +13,10 @@ # limitations under the License. from concurrent.futures import ThreadPoolExecutor, as_completed +from multiprocessing import cpu_count import numpy as np from tqdm import tqdm -from multiprocessing import cpu_count -from concurrent.futures import ProcessPoolExecutor, as_completed from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -56,7 +55,7 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: boo if not is_parallel: dataset.meta.episodes_stats[ep_idx] = ep_stats - + return ep_stats, ep_idx @@ -86,14 +85,12 @@ def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0): print("Computing episodes stats") total_episodes = dataset.meta.total_episodes futures = [] - + max_workers = min(cpu_count(), num_workers) if num_workers > 0: with ThreadPoolExecutor(max_workers=max_workers) as executor: for ep_idx in range(total_episodes): - futures.append( - executor.submit(convert_episode_stats, dataset, ep_idx, True) - ) + futures.append(executor.submit(convert_episode_stats, dataset, ep_idx, True)) for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"): ep_stats, ep_data = future.result() dataset.meta.episodes_stats[ep_idx] = ep_data @@ -103,7 +100,7 @@ def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0): for ep_idx in tqdm(range(total_episodes)): write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) - + def check_aggregate_stats( dataset: LeRobotDataset, diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 129a09b0..8d361fb8 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -257,11 +257,11 @@ def decode_video_frames_decord( frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames)) indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0) frames = vr.get_batch(indices) - + frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255 return frames_tensor - + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str,