This commit is contained in:
Chopin 2025-04-14 14:15:24 -07:00 committed by GitHub
commit 6585528b7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 5 deletions

View File

@ -38,7 +38,10 @@ from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats from lerobot.common.datasets.v21.convert_stats import (
check_aggregate_stats,
convert_stats_parallel,
)
V20 = "v2.0" V20 = "v2.0"
V21 = "v2.1" V21 = "v2.1"
@ -57,14 +60,15 @@ def convert_dataset(
repo_id: str, repo_id: str,
branch: str | None = None, branch: str | None = None,
num_workers: int = 4, num_workers: int = 4,
video_backend: str = "pyav",
): ):
with SuppressWarnings(): with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True, video_backend=video_backend)
if (dataset.root / EPISODES_STATS_PATH).is_file(): if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink() (dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers) convert_stats_parallel(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root) ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats) check_aggregate_stats(dataset, ref_stats)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
@ -30,7 +31,7 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
return video_frames[ft_key].numpy() return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool = False):
ep_start_idx = dataset.episode_data_index["from"][ep_idx] ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx] ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
@ -52,7 +53,10 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
} }
dataset.meta.episodes_stats[ep_idx] = ep_stats if not is_parallel:
dataset.meta.episodes_stats[ep_idx] = ep_stats
return ep_stats, ep_idx
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
@ -75,6 +79,29 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0):
"""Convert stats in parallel using multiple thread."""
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
futures = []
max_workers = min(cpu_count(), num_workers)
if num_workers > 0:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for ep_idx in range(total_episodes):
futures.append(executor.submit(convert_episode_stats, dataset, ep_idx, True))
for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"):
ep_stats, ep_data = future.result()
dataset.meta.episodes_stats[ep_idx] = ep_data
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats( def check_aggregate_stats(
dataset: LeRobotDataset, dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]], reference_stats: dict[str, dict[str, np.ndarray]],

View File

@ -23,6 +23,8 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar from typing import Any, ClassVar
import decord
import numpy as np
import pyarrow as pa import pyarrow as pa
import torch import torch
import torchvision import torchvision
@ -66,6 +68,8 @@ def decode_video_frames(
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]: elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
elif backend == "decord":
return decode_video_frames_decord(video_path, timestamps)
else: else:
raise ValueError(f"Unsupported video backend: {backend}") raise ValueError(f"Unsupported video backend: {backend}")
@ -243,6 +247,21 @@ def decode_video_frames_torchcodec(
return closest_frames return closest_frames
def decode_video_frames_decord(
video_path: Path | str,
timestamps: list[float],
) -> torch.Tensor:
video_path = str(video_path)
vr = decord.VideoReader(video_path)
num_frames = len(vr)
frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames))
indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0)
frames = vr.get_batch(indices)
frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255
return frames_tensor
def encode_video_frames( def encode_video_frames(
imgs_dir: Path | str, imgs_dir: Path | str,
video_path: Path | str, video_path: Path | str,