86 lines
3.5 KiB
Python
86 lines
3.5 KiB
Python
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
|
||
|
import numpy as np
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||
|
from lerobot.common.datasets.utils import write_episode_stats
|
||
|
|
||
|
|
||
|
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
    """Decode a subsampled selection of frames from one episode's video feature.

    Frame indices are picked by `sample_indices` over the episode length, mapped
    to query timestamps, then decoded through the dataset's video backend.
    Returns the decoded frames for `ft_key` as a numpy array.
    """
    episode_length = dataset.meta.episodes[episode_index]["length"]
    frame_indices = sample_indices(episode_length)
    timestamps = dataset._get_query_timestamps(0.0, {ft_key: frame_indices})
    decoded = dataset._query_videos(timestamps, episode_index)
    return decoded[ft_key].numpy()
|
||
|
|
||
|
|
||
|
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
    """Compute per-feature statistics for one episode and store them in-place.

    Video features are subsampled (decoding every frame would be expensive);
    every other feature uses the full episode rows. Results are written to
    `dataset.meta.episodes_stats[ep_idx]`.
    """
    start = dataset.episode_data_index["from"][ep_idx]
    end = dataset.episode_data_index["to"][ep_idx]
    episode_rows = dataset.hf_dataset.select(range(start, end))

    ep_stats = {}
    for key, ft in dataset.features.items():
        is_visual = ft["dtype"] in ["image", "video"]

        if ft["dtype"] == "video":
            # We sample only for videos
            values = sample_episode_video_frames(dataset, ep_idx, key)
        else:
            values = np.array(episode_rows[key])

        # Visual data reduces over batch + spatial axes (keeping channels, with
        # dims kept for broadcasting); other features reduce over the batch axis.
        reduce_axes = (0, 2, 3) if is_visual else 0
        keep_dims = True if is_visual else values.ndim == 1
        ep_stats[key] = get_feature_stats(values, axis=reduce_axes, keepdims=keep_dims)

        if is_visual:  # remove batch dim
            ep_stats[key] = {
                stat_name: stat_val if stat_name == "count" else np.squeeze(stat_val, axis=0)
                for stat_name, stat_val in ep_stats[key].items()
            }

    dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||
|
|
||
|
|
||
|
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
    """Compute stats for every episode, then persist them to disk.

    With `num_workers > 0` the per-episode work is fanned out over threads
    (video decoding is I/O-heavy, so threads overlap well); otherwise episodes
    are processed sequentially. Requires the full dataset (no episode subset).
    """
    assert dataset.episodes is None
    print("Computing episodes stats")
    total_episodes = dataset.meta.total_episodes

    if num_workers > 0:
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            pending = {
                pool.submit(convert_episode_stats, dataset, idx): idx
                for idx in range(total_episodes)
            }
            for done in tqdm(as_completed(pending), total=total_episodes):
                done.result()  # surfaces any exception raised in a worker
    else:
        for idx in tqdm(range(total_episodes)):
            convert_episode_stats(dataset, idx)

    # Write each episode's stats now that all of them have been computed.
    for idx in tqdm(range(total_episodes)):
        write_episode_stats(idx, dataset.meta.episodes_stats[idx], dataset.root)
|
||
|
|
||
|
|
||
|
def check_aggregate_stats(
    dataset: LeRobotDataset,
    reference_stats: dict[str, dict[str, np.ndarray]],
    video_rtol_atol: tuple[float] = (1e-2, 1e-2),
    default_rtol_atol: tuple[float] = (5e-6, 0.0),
):
    """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
    for key, ft in dataset.features.items():
        # These values might need some fine-tuning.
        # Videos get looser tolerances to account for image sub-sampling.
        rtol, atol = video_rtol_atol if ft["dtype"] == "video" else default_rtol_atol

        for stat, val in agg_stats[key].items():
            # Skip stats that have no counterpart in the reference.
            if key not in reference_stats or stat not in reference_stats[key]:
                continue
            np.testing.assert_allclose(
                val,
                reference_stats[key][stat],
                rtol=rtol,
                atol=atol,
                err_msg=f"feature='{key}' stats='{stat}'",
            )