add convert stats in parallel using multiple thread and decord video backend.
This commit is contained in:
parent
2c86fea78a
commit
d8f53d93c0
|
@ -38,7 +38,7 @@ from huggingface_hub import HfApi
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats, convert_stats_parallel
|
||||||
|
|
||||||
V20 = "v2.0"
|
V20 = "v2.0"
|
||||||
V21 = "v2.1"
|
V21 = "v2.1"
|
||||||
|
@ -57,14 +57,15 @@ def convert_dataset(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
num_workers: int = 4,
|
num_workers: int = 4,
|
||||||
|
video_backend: str = "pyav",
|
||||||
):
|
):
|
||||||
with SuppressWarnings():
|
with SuppressWarnings():
|
||||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True, video_backend=video_backend)
|
||||||
|
|
||||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||||
|
|
||||||
convert_stats(dataset, num_workers=num_workers)
|
convert_stats_parallel(dataset, num_workers=num_workers)
|
||||||
ref_stats = load_stats(dataset.root)
|
ref_stats = load_stats(dataset.root)
|
||||||
check_aggregate_stats(dataset, ref_stats)
|
check_aggregate_stats(dataset, ref_stats)
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
@ -30,7 +32,7 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
|
||||||
return video_frames[ft_key].numpy()
|
return video_frames[ft_key].numpy()
|
||||||
|
|
||||||
|
|
||||||
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool = False):
|
||||||
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
||||||
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
||||||
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
||||||
|
@ -52,7 +54,10 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
if not is_parallel:
|
||||||
|
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||||
|
|
||||||
|
return ep_stats, ep_idx
|
||||||
|
|
||||||
|
|
||||||
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||||
|
@ -75,6 +80,31 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0):
|
||||||
|
"""Convert stats in parallel using multiple thread."""
|
||||||
|
assert dataset.episodes is None
|
||||||
|
print("Computing episodes stats")
|
||||||
|
total_episodes = dataset.meta.total_episodes
|
||||||
|
futures = []
|
||||||
|
|
||||||
|
max_workers = min(cpu_count(), num_workers)
|
||||||
|
if num_workers > 0:
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
for ep_idx in range(total_episodes):
|
||||||
|
futures.append(
|
||||||
|
executor.submit(convert_episode_stats, dataset, ep_idx, True)
|
||||||
|
)
|
||||||
|
for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"):
|
||||||
|
ep_stats, ep_data = future.result()
|
||||||
|
dataset.meta.episodes_stats[ep_idx] = ep_data
|
||||||
|
else:
|
||||||
|
for ep_idx in tqdm(range(total_episodes)):
|
||||||
|
convert_episode_stats(dataset, ep_idx)
|
||||||
|
|
||||||
|
for ep_idx in tqdm(range(total_episodes)):
|
||||||
|
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||||
|
|
||||||
|
|
||||||
def check_aggregate_stats(
|
def check_aggregate_stats(
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
reference_stats: dict[str, dict[str, np.ndarray]],
|
reference_stats: dict[str, dict[str, np.ndarray]],
|
||||||
|
|
|
@ -23,6 +23,7 @@ from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
import decord
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
|
@ -66,6 +67,8 @@ def decode_video_frames(
|
||||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||||
elif backend in ["pyav", "video_reader"]:
|
elif backend in ["pyav", "video_reader"]:
|
||||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||||
|
elif backend == "decord":
|
||||||
|
return decode_video_frames_decord(video_path, timestamps)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported video backend: {backend}")
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
|
@ -243,6 +246,21 @@ def decode_video_frames_torchcodec(
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_frames_decord(
|
||||||
|
video_path: Path | str,
|
||||||
|
timestamps: list[float],
|
||||||
|
) -> torch.Tensor::
|
||||||
|
video_path = str(video_path)
|
||||||
|
vr = decord.VideoReader(video_path)
|
||||||
|
num_frames = len(vr)
|
||||||
|
frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames))
|
||||||
|
indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0)
|
||||||
|
frames = vr.get_batch(indices)
|
||||||
|
|
||||||
|
frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255
|
||||||
|
return frames_tensor
|
||||||
|
|
||||||
|
|
||||||
def encode_video_frames(
|
def encode_video_frames(
|
||||||
imgs_dir: Path | str,
|
imgs_dir: Path | str,
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
|
|
Loading…
Reference in New Issue