This commit is contained in:
Chopin 2025-04-09 15:51:30 +00:00 committed by GitHub
commit a8901684c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 5 deletions

View File

@ -38,7 +38,10 @@ from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.common.datasets.v21.convert_stats import (
check_aggregate_stats,
convert_stats_parallel,
)
V20 = "v2.0"
V21 = "v2.1"
@ -57,14 +60,15 @@ def convert_dataset(
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
video_backend: str = "pyav",
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True, video_backend=video_backend)
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
convert_stats_parallel(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)

View File

@ -13,6 +13,7 @@
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
import numpy as np
from tqdm import tqdm
@ -30,7 +31,7 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool = False):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
@ -52,7 +53,10 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
if not is_parallel:
dataset.meta.episodes_stats[ep_idx] = ep_stats
return ep_stats, ep_idx
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
@ -75,6 +79,29 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0):
"""Convert stats in parallel using multiple thread."""
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
futures = []
max_workers = min(cpu_count(), num_workers)
if num_workers > 0:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for ep_idx in range(total_episodes):
futures.append(executor.submit(convert_episode_stats, dataset, ep_idx, True))
for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"):
ep_stats, ep_data = future.result()
dataset.meta.episodes_stats[ep_idx] = ep_data
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],

View File

@ -23,6 +23,8 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar
import decord
import numpy as np
import pyarrow as pa
import torch
import torchvision
@ -66,6 +68,8 @@ def decode_video_frames(
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
elif backend == "decord":
return decode_video_frames_decord(video_path, timestamps)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@ -243,6 +247,21 @@ def decode_video_frames_torchcodec(
return closest_frames
def decode_video_frames_decord(
video_path: Path | str,
timestamps: list[float],
) -> torch.Tensor:
video_path = str(video_path)
vr = decord.VideoReader(video_path)
num_frames = len(vr)
frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames))
indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0)
frames = vr.get_batch(indices)
frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255
return frames_tensor
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,