Merge f97bcd30e2 into 5322417c03

commit a8901684c7
lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py

@@ -38,7 +38,10 @@ from huggingface_hub import HfApi
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
-from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+from lerobot.common.datasets.v21.convert_stats import (
+    check_aggregate_stats,
+    convert_stats_parallel,
+)
 
 V20 = "v2.0"
 V21 = "v2.1"
 
@@ -57,14 +60,15 @@ def convert_dataset(
     repo_id: str,
     branch: str | None = None,
     num_workers: int = 4,
+    video_backend: str = "pyav",
 ):
     with SuppressWarnings():
-        dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
+        dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True, video_backend=video_backend)
 
     if (dataset.root / EPISODES_STATS_PATH).is_file():
         (dataset.root / EPISODES_STATS_PATH).unlink()
 
-    convert_stats(dataset, num_workers=num_workers)
+    convert_stats_parallel(dataset, num_workers=num_workers)
     ref_stats = load_stats(dataset.root)
     check_aggregate_stats(dataset, ref_stats)
 
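For context, a minimal driver sketch (not part of this commit; the module path and example repo id are assumptions) showing how the new video_backend argument flows through convert_dataset:

    from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

    # Convert a v2.0 dataset to v2.1, decoding videos with decord and
    # computing per-episode stats on 8 worker threads.
    convert_dataset(
        repo_id="lerobot/pusht",  # assumed example repo id
        num_workers=8,
        video_backend="decord",
    )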
lerobot/common/datasets/v21/convert_stats.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from multiprocessing import cpu_count
 
 import numpy as np
 from tqdm import tqdm
@@ -30,7 +31,7 @@ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_
     return video_frames[ft_key].numpy()
 
 
-def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
+def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool = False):
     ep_start_idx = dataset.episode_data_index["from"][ep_idx]
     ep_end_idx = dataset.episode_data_index["to"][ep_idx]
     ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
@@ -52,7 +53,10 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
             k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
         }
 
-    dataset.meta.episodes_stats[ep_idx] = ep_stats
+    if not is_parallel:
+        dataset.meta.episodes_stats[ep_idx] = ep_stats
+
+    return ep_stats, ep_idx
 
 
 def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
@@ -75,6 +79,29 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
         write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
 
 
+def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0):
+    """Convert stats in parallel using multiple threads."""
+    assert dataset.episodes is None
+    print("Computing episodes stats")
+    total_episodes = dataset.meta.total_episodes
+    futures = []
+
+    max_workers = min(cpu_count(), num_workers)
+    if num_workers > 0:
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            for ep_idx in range(total_episodes):
+                futures.append(executor.submit(convert_episode_stats, dataset, ep_idx, True))
+            for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"):
+                ep_stats, ep_idx = future.result()
+                dataset.meta.episodes_stats[ep_idx] = ep_stats
+    else:
+        for ep_idx in tqdm(range(total_episodes)):
+            convert_episode_stats(dataset, ep_idx)
+
+    for ep_idx in tqdm(range(total_episodes)):
+        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
+
+
 def check_aggregate_stats(
     dataset: LeRobotDataset,
     reference_stats: dict[str, dict[str, np.ndarray]],
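If only the per-episode statistics need to be regenerated, the new helper can also be driven directly. A minimal sketch (not part of this commit; the repo id is an assumption), mirroring what convert_dataset does above:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.common.datasets.utils import load_stats
    from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats_parallel

    dataset = LeRobotDataset("lerobot/pusht", revision="v2.0", force_cache_sync=True)
    convert_stats_parallel(dataset, num_workers=8)  # one thread-pool task per episode
    check_aggregate_stats(dataset, load_stats(dataset.root))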
lerobot/common/datasets/video_utils.py

@@ -23,6 +23,8 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, ClassVar
 
+import decord
+import numpy as np
 import pyarrow as pa
 import torch
 import torchvision
@@ -66,6 +68,8 @@ def decode_video_frames(
         return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
     elif backend in ["pyav", "video_reader"]:
         return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+    elif backend == "decord":
+        return decode_video_frames_decord(video_path, timestamps)
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
 
@@ -243,6 +247,21 @@ def decode_video_frames_torchcodec(
     return closest_frames
 
 
+def decode_video_frames_decord(
+    video_path: Path | str,
+    timestamps: list[float],
+) -> torch.Tensor:
+    video_path = str(video_path)
+    vr = decord.VideoReader(video_path)
+    num_frames = len(vr)
+    frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames))
+    indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0)
+    frames = vr.get_batch(indices)
+
+    frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255
+    return frames_tensor
+
+
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
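For reference, a small usage sketch of the decord path through decode_video_frames (not part of this commit; the video path is a placeholder and the call assumes the existing (video_path, timestamps, tolerance_s, backend) signature):

    from lerobot.common.datasets.video_utils import decode_video_frames

    frames = decode_video_frames(
        video_path="videos/observation.image/episode_000000.mp4",  # placeholder path
        timestamps=[0.0, 0.5],
        tolerance_s=1e-4,  # not used by the decord path, which picks the nearest frames
        backend="decord",
    )
    # frames: float32 tensor of shape (2, C, H, W) with values in [0, 1]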