100 lines
4.1 KiB
Python
100 lines
4.1 KiB
Python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.common.datasets.utils import write_episode_stats
|
|
|
|
|
|
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
    """Decode a sampled subset of frames of one video feature for one episode.

    Args:
        dataset: Dataset holding the episode metadata and the video files.
        episode_index: Index of the episode to read from.
        ft_key: Name of the video feature to decode.

    Returns:
        The decoded frames for ``ft_key`` converted to a NumPy array.
    """
    episode_length = dataset.meta.episodes[episode_index]["length"]

    # The sampled indices are handed to the dataset as per-feature offsets
    # relative to a current timestamp of 0.0.
    timestamps = dataset._get_query_timestamps(0.0, {ft_key: sample_indices(episode_length)})

    frames = dataset._query_videos(timestamps, episode_index)
    return frames[ft_key].numpy()
|
|
|
|
|
|
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
    """Compute per-feature statistics for one episode and store them on the dataset.

    The result is written to ``dataset.meta.episodes_stats[ep_idx]``; nothing
    is returned.

    Args:
        dataset: Dataset to read the episode from and to store the stats on.
        ep_idx: Index of the episode to process.
    """
    first_frame = dataset.episode_data_index["from"][ep_idx]
    last_frame = dataset.episode_data_index["to"][ep_idx]
    episode_rows = dataset.hf_dataset.select(range(first_frame, last_frame))

    stats_by_feature = {}
    for name, feature in dataset.features.items():
        dtype = feature["dtype"]
        is_visual = dtype in ["image", "video"]

        if dtype == "video":
            # Only video features are sub-sampled; everything else uses all rows.
            values = sample_episode_video_frames(dataset, ep_idx, name)
        else:
            values = np.array(episode_rows[name])

        if is_visual:
            # Reduce over every axis except axis 1
            # (presumably the channel axis of (N, C, H, W) data -- TODO confirm).
            reduce_axes, keep = (0, 2, 3), True
        else:
            # 1-D features keep their reduced axis so the stats are not 0-d.
            reduce_axes, keep = 0, values.ndim == 1

        feature_stats = get_feature_stats(values, axis=reduce_axes, keepdims=keep)

        if is_visual:
            # Drop the leading (batch) axis kept by keepdims, except for "count".
            feature_stats = {
                stat: value if stat == "count" else np.squeeze(value, axis=0)
                for stat, value in feature_stats.items()
            }

        stats_by_feature[name] = feature_stats

    dataset.meta.episodes_stats[ep_idx] = stats_by_feature
|
|
|
|
|
|
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
    """Compute the stats of every episode, then write them out episode by episode.

    Args:
        dataset: Dataset to process. ``dataset.episodes`` must be ``None``
            (presumably meaning no episode subset was selected -- TODO confirm).
        num_workers: Number of threads used to compute the stats. Any value
            that is not greater than 0 runs the computation sequentially.
    """
    assert dataset.episodes is None
    print("Computing episodes stats")
    n_episodes = dataset.meta.total_episodes
    episode_ids = range(n_episodes)

    if num_workers > 0:
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            pending = [pool.submit(convert_episode_stats, dataset, episode) for episode in episode_ids]
            for finished in tqdm(as_completed(pending), total=n_episodes):
                # result() re-raises any exception raised inside the worker.
                finished.result()
    else:
        for episode in tqdm(episode_ids):
            convert_episode_stats(dataset, episode)

    # Writing only starts once the stats of all episodes have been computed.
    for episode in tqdm(episode_ids):
        write_episode_stats(episode, dataset.meta.episodes_stats[episode], dataset.root)
|
|
|
|
|
|
def check_aggregate_stats(
    dataset: LeRobotDataset,
    reference_stats: dict[str, dict[str, np.ndarray]],
    video_rtol_atol: tuple[float, float] = (1e-2, 1e-2),
    default_rtol_atol: tuple[float, float] = (5e-6, 6e-5),
):
    """Verify that the aggregated stats from episodes_stats are close to reference stats.

    The per-episode stats stored in ``dataset.meta.episodes_stats`` are
    aggregated and each resulting statistic is compared with the matching
    entry of ``reference_stats``. Features or statistics that are absent
    from ``reference_stats`` are skipped.

    Args:
        dataset: Dataset whose ``meta.episodes_stats`` are aggregated.
        reference_stats: Expected values, indexed by feature name and then by
            statistic name.
        video_rtol_atol: ``(rtol, atol)`` pair used for features whose dtype
            is ``"video"``; looser to account for image sub-sampling.
        default_rtol_atol: ``(rtol, atol)`` pair used for every other feature.

    Raises:
        AssertionError: If a statistic differs from its reference by more
            than the applicable tolerance.
    """
    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
    for key, ft in dataset.features.items():
        # These values might need some fine-tuning
        rtol, atol = video_rtol_atol if ft["dtype"] == "video" else default_rtol_atol

        # Looked up before the reference check so that a feature missing from
        # the aggregated stats still raises KeyError.
        feature_stats = agg_stats[key]
        if key not in reference_stats:
            continue
        reference = reference_stats[key]

        for stat, val in feature_stats.items():
            if stat not in reference:
                continue
            err_msg = f"feature='{key}' stats='{stat}'"
            np.testing.assert_allclose(val, reference[stat], rtol=rtol, atol=atol, err_msg=err_msg)
|