Per-episode stats (#521)
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>
parent 7c2bbee613
commit 8426c64f42
@@ -148,6 +148,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
     action = zarr_data["action"][:]
     image = zarr_data["img"]  # (b, h, w, c)
+
+    if image.dtype == np.float32 and image.max() == np.float32(255):
+        # HACK: images are loaded as float32 but they actually encode uint8 data
+        image = image.astype(np.uint8)

     episode_data_index = {
         "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
         "to": zarr_data.meta["episode_ends"],
@@ -13,202 +13,148 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import deepcopy
-from math import ceil
+import numpy as np

-import einops
-import torch
-import tqdm
+from lerobot.common.datasets.utils import load_image_as_numpy


-def get_stats_einops_patterns(dataset, num_workers=0):
-    """These einops patterns will be used to aggregate batches and compute statistics.
-
-    Note: We assume the images are in channel first format
-    """
+def estimate_num_samples(
+    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
+) -> int:
+    """Heuristic to estimate the number of samples based on dataset size.
+    The power controls the sample growth relative to dataset size.
+    Lower the power for fewer samples.
+
+    For default arguments, we have:
+    - from 1 to ~500, num_samples=100
+    - at 1000, num_samples=177
+    - at 2000, num_samples=299
+    - at 5000, num_samples=594
+    - at 10000, num_samples=1000
+    - at 20000, num_samples=1681
+    """
+    if dataset_len < min_num_samples:
+        min_num_samples = dataset_len
+    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
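As a quick check of the docstring's table (this snippet only restates the function above, nothing new): the values follow from `int(dataset_len ** 0.75)` clamped to `[min_num_samples, max_num_samples]`.

```python
# 1000 ** 0.75 ≈ 177.8 → 177; 20_000 ** 0.75 ≈ 1681.8 → 1681
assert estimate_num_samples(1_000) == 177
assert estimate_num_samples(20_000) == 1681
assert estimate_num_samples(500_000) == 10_000  # clamped at max_num_samples
assert estimate_num_samples(50) == 50           # datasets smaller than min_num_samples use all frames
```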

-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=num_workers,
-        batch_size=2,
-        shuffle=False,
-    )
-    batch = next(iter(dataloader))
-
-    stats_patterns = {}
+def sample_indices(data_len: int) -> list[int]:
+    num_samples = estimate_num_samples(data_len)
+    return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()

-    for key in dataset.features:
-        # sanity check that tensors are not float64
-        assert batch[key].dtype != torch.float64
-
-        # if isinstance(feats_type, (VideoFrame, Image)):
-        if key in dataset.meta.camera_keys:
-            # sanity check that images are channel first
-            _, c, h, w = batch[key].shape
-            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-
-            # sanity check that images are float32 in range [0,1]
-            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
-            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
-            assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"
+def sample_images(image_paths: list[str]) -> np.ndarray:
+    sampled_indices = sample_indices(len(image_paths))
+    images = []
+    for idx in sampled_indices:
+        path = image_paths[idx]
+        # we load as uint8 to reduce memory usage
+        img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
+        images.append(img)
+
+    images = np.stack(images)
+    return images

-            stats_patterns[key] = "b c h w -> c 1 1"
-        elif batch[key].ndim == 2:
-            stats_patterns[key] = "b c -> c "
-        elif batch[key].ndim == 1:
-            stats_patterns[key] = "b -> 1"
+def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
+    return {
+        "min": np.min(array, axis=axis, keepdims=keepdims),
+        "max": np.max(array, axis=axis, keepdims=keepdims),
+        "mean": np.mean(array, axis=axis, keepdims=keepdims),
+        "std": np.std(array, axis=axis, keepdims=keepdims),
+        "count": np.array([len(array)]),
+    }


+def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
+    ep_stats = {}
+    for key, data in episode_data.items():
+        if features[key]["dtype"] == "string":
+            continue  # HACK: we should receive np.arrays of strings
+        elif features[key]["dtype"] in ["image", "video"]:
+            ep_ft_array = sample_images(data)  # data is a list of image paths
+            axes_to_reduce = (0, 2, 3)  # keep channel dim
+            keepdims = True
         else:
-            raise ValueError(f"{key}, {batch[key].shape}")
+            ep_ft_array = data  # data is already a np.ndarray
+            axes_to_reduce = 0  # compute stats over the first axis
+            keepdims = data.ndim == 1  # keep as np.array

-    return stats_patterns
+        ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
+
+        # finally, we normalize and remove batch dim for images
+        if features[key]["dtype"] in ["image", "video"]:
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    return ep_stats
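A usage sketch of `compute_episode_stats` (feature name invented for illustration): plain arrays are reduced over the first axis, while image/video features would instead be passed as lists of file paths, subsampled via `sample_images`, and normalized to [0, 1].

```python
import numpy as np

episode_data = {"observation.state": np.random.rand(50, 6).astype(np.float32)}
features = {"observation.state": {"dtype": "float32"}}

ep_stats = compute_episode_stats(episode_data, features)
assert ep_stats["observation.state"]["mean"].shape == (6,)  # reduced over the 50 frames
assert ep_stats["observation.state"]["count"][0] == 50
```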


-def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
-    """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
-    if max_num_samples is None:
-        max_num_samples = len(dataset)
-
-    # for more info on why we need to set the same number of workers, see `load_from_videos`
-    stats_patterns = get_stats_einops_patterns(dataset, num_workers)
-
-    # mean and std will be computed incrementally while max and min will track the running value.
-    mean, std, max, min = {}, {}, {}, {}
-    for key in stats_patterns:
-        mean[key] = torch.tensor(0.0).float()
-        std[key] = torch.tensor(0.0).float()
-        max[key] = torch.tensor(-float("inf")).float()
-        min[key] = torch.tensor(float("inf")).float()
-
-    def create_seeded_dataloader(dataset, batch_size, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=num_workers,
-            batch_size=batch_size,
-            shuffle=True,
-            drop_last=False,
-            generator=generator,
-        )
-        return dataloader
-
-    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
-    # surprises when rerunning the sampler.
-    first_batch = None
-    running_item_count = 0  # for online mean computation
-    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        if first_batch is None:
-            first_batch = deepcopy(batch)
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation.
-            batch_mean = einops.reduce(batch[key], pattern, "mean")
-            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
-            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
-            # and x is the current batch mean. Some rearrangement is then required to avoid risking
-            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
-            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
-            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
-            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    first_batch_ = None
-    running_item_count = 0  # for online std computation
-    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        # Sanity check to make sure the batches are still in the same order as before.
-        if first_batch_ is None:
-            first_batch_ = deepcopy(batch)
-            for key in stats_patterns:
-                assert torch.equal(first_batch_[key], first_batch[key])
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation (where the mean is over squared
-            # residuals). See notes in the mean computation loop above.
-            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
-            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    for key in stats_patterns:
-        std[key] = torch.sqrt(std[key])
-
-    stats = {}
-    for key in stats_patterns:
-        stats[key] = {
-            "mean": mean[key],
-            "std": std[key],
-            "max": max[key],
-            "min": min[key],
-        }
-    return stats
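For reference, the update rule in the removed `compute_stats` above is the standard streaming-mean recurrence, written out in the notation of its own comments: x̄ₙ = (Nₙ₋₁·x̄ₙ₋₁ + Bₙ·xₙ) / Nₙ = x̄ₙ₋₁ + Bₙ·(xₙ − x̄ₙ₋₁) / Nₙ, using Nₙ₋₁ = Nₙ − Bₙ. The std pass applies the same recurrence to the running mean of squared residuals (xₙ there is the batch mean of (x − x̄)²) and takes a square root at the end.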

+def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
+    for i in range(len(stats_list)):
+        for fkey in stats_list[i]:
+            for k, v in stats_list[i][fkey].items():
+                if not isinstance(v, np.ndarray):
+                    raise ValueError(
+                        f"Stats must be composed of numpy arrays, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
+                    )
+                if v.ndim == 0:
+                    raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
+                if k == "count" and v.shape != (1,):
+                    raise ValueError(f"Shape of 'count' must be (1,), but is {v.shape} instead.")
+                if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
+                    raise ValueError(f"Shape of '{k}' must be (3, 1, 1), but is {v.shape} instead.")


-def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
-    """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
-
-    The final stats will have the union of all data keys from each of the datasets. For instance:
-    - new_max = max(max_dataset_0, max_dataset_1, ...)
-    - new_min = min(min_dataset_0, min_dataset_1, ...)
-    - new_mean = (mean of all data)
-    - new_std = (std of all data)
-    """
+def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
+    """Aggregates stats for a single feature."""
+    means = np.stack([s["mean"] for s in stats_ft_list])
+    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
+    counts = np.stack([s["count"] for s in stats_ft_list])
+    total_count = counts.sum(axis=0)
+
+    # Prepare weighted mean by matching number of dimensions
+    while counts.ndim < means.ndim:
+        counts = np.expand_dims(counts, axis=-1)
+
+    # Compute the weighted mean
+    weighted_means = means * counts
+    total_mean = weighted_means.sum(axis=0) / total_count
+
+    # Compute the variance using the parallel algorithm
+    delta_means = means - total_mean
+    weighted_variances = (variances + delta_means**2) * counts
+    total_variance = weighted_variances.sum(axis=0) / total_count
+
+    return {
+        "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
+        "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
+        "mean": total_mean,
+        "std": np.sqrt(total_variance),
+        "count": total_count,
+    }
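A quick numerical check of the parallel mean/variance combination used above (hypothetical data; the identity σ² = Σᵢ nᵢ·(σᵢ² + (μᵢ − μ)²) / N is exact, so this matches up to floating-point rounding):

```python
import numpy as np

a, b = np.random.rand(30), np.random.rand(70)
per_chunk = [
    {"min": np.array([x.min()]), "max": np.array([x.max()]), "mean": np.array([x.mean()]),
     "std": np.array([x.std()]), "count": np.array([len(x)])}
    for x in (a, b)
]
agg = aggregate_feature_stats(per_chunk)
both = np.concatenate([a, b])
np.testing.assert_allclose(agg["mean"], both.mean())  # weighted mean of chunk means
np.testing.assert_allclose(agg["std"], both.std())    # parallel variance, then sqrt
```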


+def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
+    """Aggregate stats from multiple compute_stats outputs into a single set of stats.
+
+    The final stats will have the union of all data keys from each of the stats dicts.
+
+    For instance:
+    - new_min = min(min_dataset_0, min_dataset_1, ...)
+    - new_max = max(max_dataset_0, max_dataset_1, ...)
+    - new_mean = (mean of all data, weighted by counts)
+    - new_std = (std of all data)
+    """
-    data_keys = set()
-    for dataset in ls_datasets:
-        data_keys.update(dataset.meta.stats.keys())
-    stats = {k: {} for k in data_keys}
-    for data_key in data_keys:
-        for stat_key in ["min", "max"]:
-            # compute `max(dataset_0["max"], dataset_1["max"], ...)`
-            stats[data_key][stat_key] = einops.reduce(
-                torch.stack(
-                    [ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
-                    dim=0,
-                ),
-                "n ... -> ...",
-                stat_key,
-            )
-        total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
-        # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
-        # dataset, then divide by total_samples to get the overall "mean".
-        # NOTE: the brackets around (d.num_frames / total_samples) are needed to minimize the risk of
-        # numerical overflow!
-        stats[data_key]["mean"] = sum(
-            d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
-            for d in ls_datasets
-            if data_key in d.meta.stats
-        )
-        # The derivation for standard deviation is a little more involved but is much in the same spirit as
-        # the computation of the mean.
-        # Given two sets of data where the statistics are known:
-        # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
-        # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
-        # NOTE: the brackets around (d.num_frames / total_samples) are needed to minimize the risk of
-        # numerical overflow!
-        stats[data_key]["std"] = torch.sqrt(
-            sum(
-                (
-                    d.meta.stats[data_key]["std"] ** 2
-                    + (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
-                )
-                * (d.num_frames / total_samples)
-                for d in ls_datasets
-                if data_key in d.meta.stats
-            )
-        )
-    return stats
+    _assert_type_and_shape(stats_list)
+
+    data_keys = {key for stats in stats_list for key in stats}
+    aggregated_stats = {key: {} for key in data_keys}
+
+    for key in data_keys:
+        stats_with_key = [stats[key] for stats in stats_list if key in stats]
+        aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
+
+    return aggregated_stats
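And the top-level entry point in use: per-episode stats dicts go in, one aggregated dict comes out (feature name and numbers invented for illustration):

```python
import numpy as np

def ft_stats(mn, mx, mean, std, count):
    return {"min": np.array([mn]), "max": np.array([mx]), "mean": np.array([mean]),
            "std": np.array([std]), "count": np.array([count])}

ep0 = {"observation.state": ft_stats(0.0, 1.0, 0.5, 0.1, 40)}
ep1 = {"observation.state": ft_stats(-1.0, 0.5, 0.2, 0.2, 60)}
agg = aggregate_stats([ep0, ep1])
# weighted mean: (40 * 0.5 + 60 * 0.2) / 100 = 0.32
np.testing.assert_allclose(agg["observation.state"]["mean"], np.array([0.32]))
```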

@@ -26,18 +26,17 @@ import PIL.Image
 import torch
 import torch.utils
 from datasets import load_dataset
-from huggingface_hub import create_repo, snapshot_download, upload_folder
+from huggingface_hub import HfApi, snapshot_download

-from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
+from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
     EPISODES_PATH,
     INFO_PATH,
     STATS_PATH,
     TASKS_PATH,
     append_jsonlines,
+    backward_compatible_episodes_stats,
     check_delta_timestamps,
     check_frame_features,
     check_timestamps_sync,

@@ -52,10 +51,13 @@ from lerobot.common.datasets.utils import (
     get_hub_safe_version,
     hf_transform_to_torch,
     load_episodes,
+    load_episodes_stats,
     load_info,
     load_stats,
     load_tasks,
     serialize_dict,
+    write_episode,
+    write_episode_stats,
+    write_info,
     write_json,
     write_parquet,
 )
@@ -90,6 +92,17 @@ class LeRobotDatasetMetadata:
-        self.stats = load_stats(self.root)
         self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)
+        try:
+            self.episodes_stats = load_episodes_stats(self.root)
+            self.stats = aggregate_stats(list(self.episodes_stats.values()))
+        except FileNotFoundError:
+            logging.warning(
+                f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead.
+                Convert your dataset stats to the new format using this command:
+                python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """
+            )
+            self.stats = load_stats(self.root)
+            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)

     def pull_from_repo(
         self,
@@ -228,7 +241,13 @@ class LeRobotDatasetMetadata:
         }
         append_jsonlines(task_dict, self.root / TASKS_PATH)

-    def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
+    def save_episode(
+        self,
+        episode_index: int,
+        episode_length: int,
+        episode_tasks: list[str],
+        episode_stats: dict[str, dict],
+    ) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length
@@ -238,21 +257,19 @@ class LeRobotDatasetMetadata:
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
-        write_json(self.info, self.root / INFO_PATH)
+        write_info(self.info, self.root)

         episode_dict = {
             "episode_index": episode_index,
             "tasks": episode_tasks,
             "length": episode_length,
         }
-        self.episodes.append(episode_dict)
-        append_jsonlines(episode_dict, self.root / EPISODES_PATH)
+        self.episodes[episode_index] = episode_dict
+        write_episode(episode_dict, self.root)

-        # TODO(aliberts): refactor stats in save_episodes
-        # image_sampling = int(self.fps / 2)  # sample 2 img/s for the stats
-        # ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
-        # ep_stats = serialize_dict(ep_stats)
-        # append_jsonlines(ep_stats, self.root / STATS_PATH)
+        self.episodes_stats[episode_index] = episode_stats
+        self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
+        write_episode_stats(episode_index, episode_stats, self.root)

     def write_video_info(self) -> None:
         """
@@ -309,6 +326,7 @@ class LeRobotDatasetMetadata:
             )
         else:
+            # TODO(aliberts, rcadene): implement sanity check for features
             features = {**features, **DEFAULT_FEATURES}

         # check if none of the features contains a "/" in their names,
         # as this would break the dict flattening in the stats computation, which uses '/' as separator
@@ -319,7 +337,7 @@ class LeRobotDatasetMetadata:
         features = {**features, **DEFAULT_FEATURES}

         obj.tasks, obj.task_to_task_index = {}, {}
-        obj.stats, obj.episodes = {}, []
+        obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
         obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
@@ -457,6 +475,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Load metadata
         self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
+        if self.episodes is not None and self.meta._version == CODEBASE_VERSION:
+            episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
+            self.stats = aggregate_stats(episodes_stats)

         # Check version
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
@@ -479,10 +500,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def push_to_hub(
         self,
         branch: str | None = None,
+        create_card: bool = True,
         tags: list | None = None,
         license: str | None = "apache-2.0",
         push_videos: bool = True,
         private: bool = False,
+        allow_patterns: list[str] | str | None = None,
         **card_kwargs,
     ) -> None:
         if not self.consolidated:
@@ -496,24 +520,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if not push_videos:
             ignore_patterns.append("videos/")

-        create_repo(
+        hub_api = HfApi()
+        hub_api.create_repo(
             repo_id=self.repo_id,
             private=private,
             repo_type="dataset",
             exist_ok=True,
         )
+        if branch:
+            create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset")

-        upload_folder(
+        hub_api.upload_folder(
             repo_id=self.repo_id,
             folder_path=self.root,
             repo_type="dataset",
+            revision=branch,
+            allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
         )
-        card = create_lerobot_dataset_card(
-            tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
-        )
-        card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
-        create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
+        if create_card:
+            card = create_lerobot_dataset_card(
+                tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
+            )
+            card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
+
+        if not branch:
+            create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")

     def pull_from_repo(
         self,
@@ -630,7 +662,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if key not in self.meta.video_keys
         }

-    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
+    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
         """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
         in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
         Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -660,8 +692,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         query_indices = None
         if self.delta_indices is not None:
-            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
-            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
+            query_indices, padding = self._get_query_indices(idx, ep_idx)
             query_result = self._query_hf_dataset(query_indices)
             item = {**item, **padding}
             for key, val in query_result.items():
@@ -735,11 +766,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()

+        # Automatically add frame_index and timestamp to episode buffer
         frame_index = self.episode_buffer["size"]
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
         self.episode_buffer["frame_index"].append(frame_index)
         self.episode_buffer["timestamp"].append(timestamp)

         # Add frame features to episode_buffer
         for key in frame:
             if key == "task":
                 # Note: we associate the task in natural language to its task index during `save_episode`
@@ -787,7 +820,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # TODO(aliberts): Add option to use existing episode_index
             raise NotImplementedError(
                 "You might have manually provided the episode_buffer with an episode_index that doesn't "
-                "match the total number of episodes in the dataset. This is not supported for now."
+                "match the total number of episodes already in the dataset. This is not supported for now."
             )

         if episode_length == 0:
@@ -821,8 +854,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)

-        self.meta.save_episode(episode_index, episode_length, episode_tasks)
+        ep_stats = compute_episode_stats(episode_buffer, self.features)
+        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)

         if encode_videos and len(self.meta.video_keys) > 0:
             video_paths = self.encode_episode_videos(episode_index)
@@ -908,7 +941,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return video_paths

-    def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
+    def consolidate(self, keep_image_files: bool = False) -> None:
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
         check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -928,17 +961,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         parquet_files = list(self.root.rglob("*.parquet"))
         assert len(parquet_files) == self.num_episodes

-        if run_compute_stats:
-            self.stop_image_writer()
-            # TODO(aliberts): refactor stats in save_episodes
-            self.meta.stats = compute_stats(self)
-            serialized_stats = serialize_dict(self.meta.stats)
-            write_json(serialized_stats, self.root / STATS_PATH)
-            self.consolidated = True
-        else:
-            logging.warning(
-                "Skipping computation of the dataset statistics, dataset is not fully consolidated."
-            )
+        self.consolidated = True

     @classmethod
     def create(
@@ -1056,7 +1079,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
-        self.stats = aggregate_stats(self._datasets)
+        # TODO(rcadene, aliberts): We should not perform this aggregation for datasets
+        # with multiple robots of different ranges. Instead we should have one normalization
+        # per robot.
+        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])

     @property
     def repo_id_to_index(self):
@@ -43,6 +43,7 @@ DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
 INFO_PATH = "meta/info.json"
 EPISODES_PATH = "meta/episodes.jsonl"
 STATS_PATH = "meta/stats.json"
+EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
 TASKS_PATH = "meta/tasks.jsonl"

 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
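Taken together, the metadata folder of a v2.1 dataset now contains:

- meta/info.json — global info, including codebase_version
- meta/episodes.jsonl — one record per episode
- meta/episodes_stats.jsonl — one stats record per episode (new in this change)
- meta/stats.json — global stats, deprecated in favor of per-episode stats
- meta/tasks.jsonl — task_index to natural-language task mapping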
@@ -113,7 +114,16 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:

 def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
-    serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+    serialized_dict = {}
+    for key, value in flatten_dict(stats).items():
+        if isinstance(value, (torch.Tensor, np.ndarray)):
+            serialized_dict[key] = value.tolist()
+        elif isinstance(value, np.generic):
+            serialized_dict[key] = value.item()
+        elif isinstance(value, (int, float)):
+            serialized_dict[key] = value
+        else:
+            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
     return unflatten_dict(serialized_dict)
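For intuition, `serialize_dict` walks the flattened stats (nested keys joined with "/") and converts arrays to JSON-friendly lists; a hypothetical round trip:

```python
import numpy as np

stats = {"observation.state": {"mean": np.array([0.1, 0.2]), "count": np.array([10])}}
serialized = serialize_dict(stats)
# {"observation.state": {"mean": [0.1, 0.2], "count": [10]}} — plain lists, ready for write_json
```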


@@ -154,6 +164,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
         writer.write(data)


+def write_info(info: dict, local_dir: Path):
+    write_json(info, local_dir / INFO_PATH)
+
+
 def load_info(local_dir: Path) -> dict:
     info = load_json(local_dir / INFO_PATH)
     for ft in info["features"].values():
@@ -161,12 +175,29 @@ def load_info(local_dir: Path) -> dict:
     return info


-def load_stats(local_dir: Path) -> dict:
+def write_stats(stats: dict, local_dir: Path):
+    serialized_stats = serialize_dict(stats)
+    write_json(serialized_stats, local_dir / STATS_PATH)
+
+
+def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
+    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)
+
+
+def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
+    if not (local_dir / STATS_PATH).exists():
+        return None
     stats = load_json(local_dir / STATS_PATH)
-    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
-    return unflatten_dict(stats)
+    return cast_stats_to_numpy(stats)
+
+
+def write_task(task_index: int, task: dict, local_dir: Path):
+    task_dict = {
+        "task_index": task_index,
+        "task": task,
+    }
+    append_jsonlines(task_dict, local_dir / TASKS_PATH)


 def load_tasks(local_dir: Path) -> dict:
@@ -176,16 +207,42 @@ def load_tasks(local_dir: Path) -> dict:
     return tasks, task_to_task_index


+def write_episode(episode: dict, local_dir: Path):
+    append_jsonlines(episode, local_dir / EPISODES_PATH)
+
+
 def load_episodes(local_dir: Path) -> dict:
-    return load_jsonlines(local_dir / EPISODES_PATH)
+    episodes = load_jsonlines(local_dir / EPISODES_PATH)
+    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
+
+
+def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
+    # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
+    # is a dictionary of stats and not an integer.
+    episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
+    append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
+
+
+def load_episodes_stats(local_dir: Path) -> dict:
+    episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
+    return {
+        item["episode_index"]: cast_stats_to_numpy(item["stats"])
+        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
+    }
+
+
+def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
+    return {ep_idx: stats for ep_idx in episodes}
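Concretely, `write_episode_stats` appends one JSON object per line to `meta/episodes_stats.jsonl`; a hypothetical record (feature name and values invented, arrays shortened to one dimension) looks like:

```python
# one line of meta/episodes_stats.jsonl, shown as the equivalent Python dict
{
    "episode_index": 0,
    "stats": {
        "observation.state": {"min": [-0.5], "max": [0.9], "mean": [0.1], "std": [0.3], "count": [50]}
    },
}
```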


-def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
+def load_image_as_numpy(
+    fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
+) -> np.ndarray:
     img = PILImage.open(fpath).convert("RGB")
     img_array = np.array(img, dtype=dtype)
     if channel_first:  # (H, W, C) -> (C, H, W)
         img_array = np.transpose(img_array, (2, 0, 1))
-    if "float" in dtype:
+    if np.issubdtype(dtype, np.floating):
         img_array /= 255.0
     return img_array
@@ -370,9 +427,9 @@ def create_empty_dataset_info(

 def get_episode_data_index(
-    episode_dicts: list[dict], episodes: list[int] | None = None
+    episode_dicts: dict[dict], episodes: list[int] | None = None
 ) -> dict[str, torch.Tensor]:
-    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
+    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
     if episodes is not None:
         episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
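For a concrete picture of what `get_episode_data_index` produces (a numpy sketch with hypothetical lengths; the function itself returns torch tensors per its annotation): the "from"/"to" arrays are cumulative frame boundaries, so episode i owns dataset rows from[i]:to[i].

```python
import numpy as np

episode_lengths = [50, 30, 20]  # hypothetical
cumulative = np.cumsum(episode_lengths)
episode_data_index = {
    "from": np.concatenate(([0], cumulative[:-1])),  # [0, 50, 80]
    "to": cumulative,                                # [50, 80, 100]
}
```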
@@ -0,0 +1,87 @@
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
+2.1. It performs the following:
+
+- Generates per-episode stats and writes them in `episodes_stats.jsonl`
+- Removes the deprecated `stats.json` (by default)
+- Updates codebase_version in `info.json`
+
+Usage:
+
+```bash
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
+    --repo-id=aliberts/koch_tutorial
+```
+
+"""
+# TODO(rcadene, aliberts): ensure this script works for any other changes for the final v2.1
+
+import argparse
+
+from huggingface_hub import HfApi
+
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
+from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+
+
+def main(
+    repo_id: str,
+    test_branch: str | None = None,
+    delete_old_stats: bool = False,
+    num_workers: int = 4,
+):
+    dataset = LeRobotDataset(repo_id)
+    if (dataset.root / EPISODES_STATS_PATH).is_file():
+        raise FileExistsError("episodes_stats.jsonl already exists.")
+
+    convert_stats(dataset, num_workers=num_workers)
+    ref_stats = load_stats(dataset.root)
+    check_aggregate_stats(dataset, ref_stats)
+
+    dataset.meta.info["codebase_version"] = CODEBASE_VERSION
+    write_info(dataset.meta.info, dataset.root)
+
+    dataset.push_to_hub(branch=test_branch, create_card=False, allow_patterns="meta/")
+
+    if delete_old_stats:
+        if (dataset.root / STATS_PATH).is_file():
+            (dataset.root / STATS_PATH).unlink()
+        hub_api = HfApi()
+        if hub_api.file_exists(
+            STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
+        ):
+            hub_api.delete_file(
+                STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
+            )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+    )
+    parser.add_argument(
+        "--test-branch",
+        type=str,
+        default=None,
+        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
+    )
+    parser.add_argument(
+        "--delete-old-stats",
+        type=bool,
+        default=False,
+        help="Delete the deprecated `stats.json`",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for parallelizing compute",
+    )
+
+    args = parser.parse_args()
+    main(**vars(args))
@@ -0,0 +1,85 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import numpy as np
+from tqdm import tqdm
+
+from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import write_episode_stats
+
+
+def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
+    ep_len = dataset.meta.episodes[episode_index]["length"]
+    sampled_indices = sample_indices(ep_len)
+    query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
+    video_frames = dataset._query_videos(query_timestamps, episode_index)
+    return video_frames[ft_key].numpy()
+
+
+def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
+    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
+    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
+
+    ep_stats = {}
+    for key, ft in dataset.features.items():
+        if ft["dtype"] == "video":
+            # We sample only for videos
+            ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
+        else:
+            ep_ft_data = np.array(ep_data[key])
+
+        axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
+        keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
+        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
+
+        if ft["dtype"] in ["image", "video"]:  # remove batch dim
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    dataset.meta.episodes_stats[ep_idx] = ep_stats
+
+
+def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
+    assert dataset.episodes is None
+    print("Computing episodes stats")
+    total_episodes = dataset.meta.total_episodes
+    if num_workers > 0:
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = {
+                executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
+                for ep_idx in range(total_episodes)
+            }
+            for future in tqdm(as_completed(futures), total=total_episodes):
+                future.result()
+    else:
+        for ep_idx in tqdm(range(total_episodes)):
+            convert_episode_stats(dataset, ep_idx)
+
+    for ep_idx in tqdm(range(total_episodes)):
+        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
+
+
+def check_aggregate_stats(
+    dataset: LeRobotDataset,
+    reference_stats: dict[str, dict[str, np.ndarray]],
+    video_rtol_atol: tuple[float] = (1e-2, 1e-2),
+    default_rtol_atol: tuple[float] = (5e-6, 0.0),
+):
+    """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
+    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
+    for key, ft in dataset.features.items():
+        # These values might need some fine-tuning
+        if ft["dtype"] == "video":
+            # to account for image sub-sampling
+            rtol, atol = video_rtol_atol
+        else:
+            rtol, atol = default_rtol_atol
+
+        for stat, val in agg_stats[key].items():
+            if key in reference_stats and stat in reference_stats[key]:
+                err_msg = f"feature='{key}' stats='{stat}'"
+                np.testing.assert_allclose(
+                    val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
+                )

@@ -69,8 +69,8 @@ def decode_video_frames_torchvision(
     # set the first and last requested timestamps
     # Note: previous timestamps are usually loaded, since we need to access the previous key frame
-    first_ts = timestamps[0]
-    last_ts = timestamps[-1]
+    first_ts = min(timestamps)
+    last_ts = max(timestamps)

     # access closest key frame of the first requested frame
     # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 import torch
 from torch import Tensor, nn
@@ -77,17 +78,29 @@ def create_stats_buffers(
         }
     )

+    # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
     if stats:
-        # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
-        # tensors anywhere (for example, when we use the same stats for normalization and
-        # unnormalization). See the logic here
-        # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
-        if norm_mode is NormalizationMode.MEAN_STD:
-            buffer["mean"].data = stats[key]["mean"].clone()
-            buffer["std"].data = stats[key]["std"].clone()
-        elif norm_mode is NormalizationMode.MIN_MAX:
-            buffer["min"].data = stats[key]["min"].clone()
-            buffer["max"].data = stats[key]["max"].clone()
+        if isinstance(stats[key]["mean"], np.ndarray):
+            if norm_mode is NormalizationMode.MEAN_STD:
+                buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
+                buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
+            elif norm_mode is NormalizationMode.MIN_MAX:
+                buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
+                buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
+        elif isinstance(stats[key]["mean"], torch.Tensor):
+            # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+            # tensors anywhere (for example, when we use the same stats for normalization and
+            # unnormalization). See the logic here
+            # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
+            if norm_mode is NormalizationMode.MEAN_STD:
+                buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
+                buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
+            elif norm_mode is NormalizationMode.MIN_MAX:
+                buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
+                buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
+        else:
+            type_ = type(stats[key]["mean"])
+            raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

         stats_buffers[key] = buffer
     return stats_buffers
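One subtlety in the numpy branch above, worth noting as editorial commentary: `torch.from_numpy(...)` shares memory with the array, and `.to(dtype=torch.float32)` returns the same tensor when the dtype already matches, so the buffer can alias the numpy stats. A small illustration:

```python
import numpy as np
import torch

arr = np.zeros(3, dtype=np.float32)
buf = torch.from_numpy(arr).to(dtype=torch.float32)  # no copy: dtype already matches
buf[0] = 1.0
assert arr[0] == 1.0  # the numpy array sees the write
```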
@@ -141,6 +154,7 @@ class Normalize(nn.Module):
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
             if key not in batch:
+                # FIXME(aliberts, rcadene): This might lead to silent fail!
                 continue

             norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
@@ -60,8 +60,6 @@ class RecordControlConfig(ControlConfig):
     num_episodes: int = 50
     # Encode frames in the dataset into video
     video: bool = True
-    # By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
-    run_compute_stats: bool = True
     # Upload dataset to Hugging Face hub.
     push_to_hub: bool = True
     # Upload on private repository on the Hugging Face hub.
@@ -301,10 +301,7 @@ def record(
     log_say("Stop recording", cfg.play_sounds, blocking=True)
     stop_recording(robot, listener, cfg.display_cameras)

-    if cfg.run_compute_stats:
-        logging.info("Computing dataset statistics")
-
-    dataset.consolidate(cfg.run_compute_stats)
+    dataset.consolidate()

     if cfg.push_to_hub:
         dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
@@ -29,7 +29,7 @@ from tests.fixtures.constants import (

 def get_task_index(task_dicts: dict, task: str) -> int:
-    tasks = {d["task_index"]: d["task"] for d in task_dicts}
+    tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
     task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
     return task_to_task_index[task]
@@ -142,6 +142,7 @@ def stats_factory():
                 "mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
                 "min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
                 "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
+                "count": [10],
             }
         else:
             stats[key] = {
@@ -149,20 +150,38 @@ def stats_factory():
                 "mean": np.full(shape, 0.5, dtype=dtype).tolist(),
                 "min": np.full(shape, 0, dtype=dtype).tolist(),
                 "std": np.full(shape, 0.25, dtype=dtype).tolist(),
+                "count": [10],
             }
         return stats

     return _create_stats


+@pytest.fixture(scope="session")
+def episodes_stats_factory(stats_factory):
+    def _create_episodes_stats(
+        features: dict[str],
+        total_episodes: int = 3,
+    ) -> dict:
+        episodes_stats = {}
+        for episode_index in range(total_episodes):
+            episodes_stats[episode_index] = {
+                "episode_index": episode_index,
+                "stats": stats_factory(features),
+            }
+        return episodes_stats
+
+    return _create_episodes_stats
+
+
 @pytest.fixture(scope="session")
 def tasks_factory():
     def _create_tasks(total_tasks: int = 3) -> int:
-        tasks_list = []
-        for i in range(total_tasks):
-            task_dict = {"task_index": i, "task": f"Perform action {i}."}
-            tasks_list.append(task_dict)
-        return tasks_list
+        tasks = {}
+        for task_index in range(total_tasks):
+            task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
+            tasks[task_index] = task_dict
+        return tasks

     return _create_tasks
@@ -191,10 +210,10 @@ def episodes_factory(tasks_factory):
         # Generate random lengths that sum up to total_length
         lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

-        tasks_list = [task_dict["task"] for task_dict in tasks]
+        tasks_list = [task_dict["task"] for task_dict in tasks.values()]
         num_tasks_available = len(tasks_list)

-        episodes_list = []
+        episodes = {}
         remaining_tasks = tasks_list.copy()
         for ep_idx in range(total_episodes):
             num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
@@ -204,15 +223,13 @@ def episodes_factory(tasks_factory):
             for task in episode_tasks:
                 remaining_tasks.remove(task)

-            episodes_list.append(
-                {
-                    "episode_index": ep_idx,
-                    "tasks": episode_tasks,
-                    "length": lengths[ep_idx],
-                }
-            )
+            episodes[ep_idx] = {
+                "episode_index": ep_idx,
+                "tasks": episode_tasks,
+                "length": lengths[ep_idx],
+            }

-        return episodes_list
+        return episodes

     return _create_episodes
@@ -236,7 +253,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
     frame_index_col = np.array([], dtype=np.int64)
     episode_index_col = np.array([], dtype=np.int64)
     task_index = np.array([], dtype=np.int64)
-    for ep_dict in episodes:
+    for ep_dict in episodes.values():
         timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
         frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
         episode_index_col = np.concatenate(
@@ -279,6 +296,7 @@
 def lerobot_dataset_metadata_factory(
     info_factory,
     stats_factory,
+    episodes_stats_factory,
     tasks_factory,
     episodes_factory,
     mock_snapshot_download_factory,

@@ -288,6 +306,7 @@ def lerobot_dataset_metadata_factory(
         repo_id: str = DUMMY_REPO_ID,
         info: dict | None = None,
         stats: dict | None = None,
+        episodes_stats: list[dict] | None = None,
         tasks: list[dict] | None = None,
         episodes: list[dict] | None = None,
         local_files_only: bool = False,

@@ -296,6 +315,10 @@ def lerobot_dataset_metadata_factory(
             info = info_factory()
         if not stats:
             stats = stats_factory(features=info["features"])
+        if not episodes_stats:
+            episodes_stats = episodes_stats_factory(
+                features=info["features"], total_episodes=info["total_episodes"]
+            )
         if not tasks:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if not episodes:

@@ -306,6 +329,7 @@ def lerobot_dataset_metadata_factory(
         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,
+            episodes_stats=episodes_stats,
             tasks=tasks,
             episodes=episodes,
         )

@@ -329,6 +353,7 @@ def lerobot_dataset_metadata_factory(
 def lerobot_dataset_factory(
     info_factory,
     stats_factory,
+    episodes_stats_factory,
     tasks_factory,
     episodes_factory,
     hf_dataset_factory,

@@ -344,6 +369,7 @@ def lerobot_dataset_factory(
     multi_task: bool = False,
     info: dict | None = None,
     stats: dict | None = None,
+    episodes_stats: list[dict] | None = None,
     tasks: list[dict] | None = None,
     episode_dicts: list[dict] | None = None,
     hf_dataset: datasets.Dataset | None = None,

@@ -355,6 +381,8 @@ def lerobot_dataset_factory(
         )
     if not stats:
         stats = stats_factory(features=info["features"])
+    if not episodes_stats:
+        episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
     if not tasks:
         tasks = tasks_factory(total_tasks=info["total_tasks"])
     if not episode_dicts:

@@ -370,6 +398,7 @@ def lerobot_dataset_factory(
     mock_snapshot_download = mock_snapshot_download_factory(
         info=info,
         stats=stats,
+        episodes_stats=episodes_stats,
         tasks=tasks,
         episodes=episode_dicts,
         hf_dataset=hf_dataset,

@@ -379,6 +408,7 @@ def lerobot_dataset_factory(
         repo_id=repo_id,
         info=info,
         stats=stats,
+        episodes_stats=episodes_stats,
         tasks=tasks,
         episodes=episode_dicts,
         local_files_only=kwargs.get("local_files_only", False),

@@ -406,7 +436,7 @@ def empty_lerobot_dataset_factory():
         robot: Robot | None = None,
         robot_type: str | None = None,
         features: dict | None = None,
-    ):
+    ) -> LeRobotDataset:
         return LeRobotDataset.create(
             repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
         )
@@ -7,7 +7,13 @@ import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import pytest

-from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
+from lerobot.common.datasets.utils import (
+    EPISODES_PATH,
+    EPISODES_STATS_PATH,
+    INFO_PATH,
+    STATS_PATH,
+    TASKS_PATH,
+)


 @pytest.fixture(scope="session")

@@ -38,6 +44,20 @@ def stats_path(stats_factory):
     return _create_stats_json_file


+@pytest.fixture(scope="session")
+def episodes_stats_path(episodes_stats_factory):
+    def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
+        if not episodes_stats:
+            episodes_stats = episodes_stats_factory()
+        fpath = dir / EPISODES_STATS_PATH
+        fpath.parent.mkdir(parents=True, exist_ok=True)
+        with jsonlines.open(fpath, "w") as writer:
+            writer.write_all(episodes_stats.values())
+        return fpath
+
+    return _create_episodes_stats_jsonl_file
+
+
 @pytest.fixture(scope="session")
 def tasks_path(tasks_factory):
     def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:

@@ -46,7 +66,7 @@ def tasks_path(tasks_factory):
         fpath = dir / TASKS_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with jsonlines.open(fpath, "w") as writer:
-            writer.write_all(tasks)
+            writer.write_all(tasks.values())
         return fpath

     return _create_tasks_jsonl_file

@@ -60,7 +80,7 @@ def episode_path(episodes_factory):
         fpath = dir / EPISODES_PATH
         fpath.parent.mkdir(parents=True, exist_ok=True)
         with jsonlines.open(fpath, "w") as writer:
-            writer.write_all(episodes)
+            writer.write_all(episodes.values())
         return fpath

     return _create_episodes_jsonl_file
@@ -4,7 +4,13 @@ import datasets
 import pytest
 from huggingface_hub.utils import filter_repo_objects

-from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
+from lerobot.common.datasets.utils import (
+    EPISODES_PATH,
+    EPISODES_STATS_PATH,
+    INFO_PATH,
+    STATS_PATH,
+    TASKS_PATH,
+)
 from tests.fixtures.constants import LEROBOT_TEST_DIR


@@ -14,6 +20,8 @@ def mock_snapshot_download_factory(
     info_path,
     stats_factory,
     stats_path,
+    episodes_stats_factory,
+    episodes_stats_path,
     tasks_factory,
     tasks_path,
     episodes_factory,

@@ -29,6 +37,7 @@ def mock_snapshot_download_factory(
     def _mock_snapshot_download_func(
         info: dict | None = None,
         stats: dict | None = None,
+        episodes_stats: list[dict] | None = None,
         tasks: list[dict] | None = None,
         episodes: list[dict] | None = None,
         hf_dataset: datasets.Dataset | None = None,

@@ -37,6 +46,10 @@ def mock_snapshot_download_factory(
             info = info_factory()
         if not stats:
             stats = stats_factory(features=info["features"])
+        if not episodes_stats:
+            episodes_stats = episodes_stats_factory(
+                features=info["features"], total_episodes=info["total_episodes"]
+            )
         if not tasks:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if not episodes:

@@ -67,11 +80,11 @@ def mock_snapshot_download_factory(
         # List all possible files
         all_files = []
-        meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
+        meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
         all_files.extend(meta_files)

         data_files = []
-        for episode_dict in episodes:
+        for episode_dict in episodes.values():
             ep_idx = episode_dict["episode_index"]
             ep_chunk = ep_idx // info["chunks_size"]
             data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)

@@ -92,6 +105,8 @@ def mock_snapshot_download_factory(
                 _ = info_path(local_dir, info)
             elif rel_path == STATS_PATH:
                 _ = stats_path(local_dir, stats)
+            elif rel_path == EPISODES_STATS_PATH:
+                _ = episodes_stats_path(local_dir, episodes_stats)
             elif rel_path == TASKS_PATH:
                 _ = tasks_path(local_dir, tasks)
             elif rel_path == EPISODES_PATH:
@@ -182,7 +182,7 @@ def test_camera(request, camera_type, mock):

 @pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
 @require_camera
-def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
+def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
     # TODO(rcadene): refactor
     if camera_type == "opencv":
         from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras

@@ -190,4 +190,4 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
     from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras

     # Small `record_time_s` to speedup unit tests
-    save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock)
+    save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
@@ -0,0 +1,311 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch

import numpy as np
import pytest

from lerobot.common.datasets.compute_stats import (
    _assert_type_and_shape,
    aggregate_feature_stats,
    aggregate_stats,
    compute_episode_stats,
    estimate_num_samples,
    get_feature_stats,
    sample_images,
    sample_indices,
)


def mock_load_image_as_numpy(path, dtype, channel_first):
    return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)


@pytest.fixture
def sample_array():
    return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])


def test_estimate_num_samples():
    assert estimate_num_samples(1) == 1
    assert estimate_num_samples(10) == 10
    assert estimate_num_samples(100) == 100
    assert estimate_num_samples(200) == 100
    assert estimate_num_samples(1000) == 177
    assert estimate_num_samples(2000) == 299
    assert estimate_num_samples(5000) == 594
    assert estimate_num_samples(10_000) == 1000
    assert estimate_num_samples(20_000) == 1681
    assert estimate_num_samples(50_000) == 3343
    assert estimate_num_samples(500_000) == 10_000
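    # A sketch of the closed form these expectations follow, assuming the
    # default arguments (min_num_samples=100, max_num_samples=10_000, power=0.75):
    #   estimate_num_samples(n) == max(min(n, 100), min(int(n**0.75), 10_000))
    # e.g. 1000**0.75 ~= 177.8 -> 177, 10_000**0.75 == 1000, and very large
    # datasets saturate at the 10_000 cap.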


def test_sample_indices():
    indices = sample_indices(10)
    assert len(indices) > 0
    assert indices[0] == 0
    assert indices[-1] == 9
    assert len(indices) == estimate_num_samples(10)
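    # sample_indices spreads samples evenly over the whole range, roughly
    # np.round(np.linspace(0, data_len - 1, num_samples)), hence the endpoints
    # 0 and data_len - 1 asserted above.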


@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
    image_paths = [f"image_{i}.jpg" for i in range(100)]
    images = sample_images(image_paths)
    assert isinstance(images, np.ndarray)
    assert images.shape[1:] == (3, 32, 32)
    assert images.dtype == np.uint8
    assert len(images) == estimate_num_samples(100)


def test_get_feature_stats_images():
    data = np.random.rand(100, 3, 32, 32)
    stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
    assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
    np.testing.assert_equal(stats["count"], np.array([100]))
    assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape


def test_get_feature_stats_axis_0_keepdims(sample_array):
    expected = {
        "min": np.array([[1, 2, 3]]),
        "max": np.array([[7, 8, 9]]),
        "mean": np.array([[4.0, 5.0, 6.0]]),
        "std": np.array([[2.44948974, 2.44948974, 2.44948974]]),
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])


def test_get_feature_stats_axis_1(sample_array):
    expected = {
        "min": np.array([1, 4, 7]),
        "max": np.array([3, 6, 9]),
        "mean": np.array([2.0, 5.0, 8.0]),
        "std": np.array([0.81649658, 0.81649658, 0.81649658]),
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])


def test_get_feature_stats_no_axis(sample_array):
    expected = {
        "min": np.array(1),
        "max": np.array(9),
        "mean": np.array(5.0),
        "std": np.array(2.5819889),
        "count": np.array([3]),
    }
    result = get_feature_stats(sample_array, axis=None, keepdims=False)
    for key in expected:
        np.testing.assert_allclose(result[key], expected[key])


def test_get_feature_stats_empty_array():
    array = np.array([])
    with pytest.raises(ValueError):
        get_feature_stats(array, axis=(0,), keepdims=True)


def test_get_feature_stats_single_value():
    array = np.array([[1337]])
    result = get_feature_stats(array, axis=None, keepdims=True)
    np.testing.assert_equal(result["min"], np.array(1337))
    np.testing.assert_equal(result["max"], np.array(1337))
    np.testing.assert_equal(result["mean"], np.array(1337.0))
    np.testing.assert_equal(result["std"], np.array(0.0))
    np.testing.assert_equal(result["count"], np.array([1]))


def test_compute_episode_stats():
    episode_data = {
        "observation.image": [f"image_{i}.jpg" for i in range(100)],
        "observation.state": np.random.rand(100, 10),
    }
    features = {
        "observation.image": {"dtype": "image"},
        "observation.state": {"dtype": "numeric"},
    }

    with patch(
        "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
    ):
        stats = compute_episode_stats(episode_data, features)

    assert "observation.image" in stats and "observation.state" in stats
    assert stats["observation.image"]["count"].item() == 100
    assert stats["observation.state"]["count"].item() == 100
    assert stats["observation.image"]["mean"].shape == (3, 1, 1)
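    # The (3, 1, 1) shape comes from reducing image stats over the batch and
    # spatial axes, (0, 2, 3) with keepdims (see test_get_feature_stats_images
    # above), leaving one value per channel so the stats can broadcast against
    # (c, h, w) frames at normalization time.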


def test_assert_type_and_shape_valid():
    valid_stats = [
        {
            "feature1": {
                "min": np.array([1.0]),
                "max": np.array([10.0]),
                "mean": np.array([5.0]),
                "std": np.array([2.0]),
                "count": np.array([1]),
            }
        }
    ]
    _assert_type_and_shape(valid_stats)


def test_assert_type_and_shape_invalid_type():
    invalid_stats = [
        {
            "feature1": {
                "min": [1.0],  # Not a numpy array
                "max": np.array([10.0]),
                "mean": np.array([5.0]),
                "std": np.array([2.0]),
                "count": np.array([1]),
            }
        }
    ]
    with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
        _assert_type_and_shape(invalid_stats)


def test_assert_type_and_shape_invalid_shape():
    invalid_stats = [
        {
            "feature1": {
                "count": np.array([1, 2]),  # Wrong shape
            }
        }
    ]
    with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
        _assert_type_and_shape(invalid_stats)


def test_aggregate_feature_stats():
    stats_ft_list = [
        {
            "min": np.array([1.0]),
            "max": np.array([10.0]),
            "mean": np.array([5.0]),
            "std": np.array([2.0]),
            "count": np.array([1]),
        },
        {
            "min": np.array([2.0]),
            "max": np.array([12.0]),
            "mean": np.array([6.0]),
            "std": np.array([2.5]),
            "count": np.array([1]),
        },
    ]
    result = aggregate_feature_stats(stats_ft_list)
    np.testing.assert_allclose(result["min"], np.array([1.0]))
    np.testing.assert_allclose(result["max"], np.array([12.0]))
    np.testing.assert_allclose(result["mean"], np.array([5.5]))
    np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6)
    np.testing.assert_allclose(result["count"], np.array([2]))
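    # A quick way to sanity-check the expected std, assuming the usual
    # pooled-variance identity over the two equally weighted episodes:
    #   E[x^2] = ((2.0**2 + 5.0**2) + (2.5**2 + 6.0**2)) / 2 = 35.625
    #   var    = E[x^2] - E[x]**2 = 35.625 - 5.5**2 = 5.375
    #   std    = 5.375**0.5 ~= 2.318405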


def test_aggregate_stats():
    all_stats = [
        {
            "observation.image": {
                "min": [1, 2, 3],
                "max": [10, 20, 30],
                "mean": [5.5, 10.5, 15.5],
                "std": [2.87, 5.87, 8.87],
                "count": 10,
            },
            "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
            "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
        },
        {
            "observation.image": {
                "min": [2, 1, 0],
                "max": [15, 10, 5],
                "mean": [8.5, 5.5, 2.5],
                "std": [3.42, 2.42, 1.42],
                "count": 15,
            },
            "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
            "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
        },
    ]

    expected_agg_stats = {
        "observation.image": {
            "min": [1, 1, 0],
            "max": [15, 20, 30],
            "mean": [7.3, 7.5, 7.7],
            "std": [3.5317, 4.8267, 8.5581],
            "count": 25,
        },
        "observation.state": {
            "min": 1,
            "max": 15,
            "mean": 7.3,
            "std": 3.5317,
            "count": 25,
        },
        "extra_key_0": {
            "min": 5,
            "max": 25,
            "mean": 15.0,
            "std": 6.0,
            "count": 6,
        },
        "extra_key_1": {
            "min": 0,
            "max": 20,
            "mean": 10.0,
            "std": 5.0,
            "count": 5,
        },
    }

    # cast to numpy
    for ep_stats in all_stats:
        for fkey, stats in ep_stats.items():
            for k in stats:
                stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
                if fkey == "observation.image" and k != "count":
                    stats[k] = stats[k].reshape(3, 1, 1)  # for normalization on image channels
                else:
                    stats[k] = stats[k].reshape(1)

    # cast to numpy
    for fkey, stats in expected_agg_stats.items():
        for k in stats:
            stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
            if fkey == "observation.image" and k != "count":
                stats[k] = stats[k].reshape(3, 1, 1)  # for normalization on image channels
            else:
                stats[k] = stats[k].reshape(1)

    results = aggregate_stats(all_stats)

    for fkey in expected_agg_stats:
        np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
        np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
        np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
        np.testing.assert_allclose(
            results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
        )
        np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
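    # Sanity check for "observation.state", assuming count-weighted aggregation
    # with the same pooled-variance identity as above:
    #   mean   = (10 * 5.5 + 15 * 8.5) / 25 = 7.3
    #   E[x^2] = (10 * (2.87**2 + 5.5**2) + 15 * (3.42**2 + 8.5**2)) / 25 ~= 65.7626
    #   std    = (65.7626 - 7.3**2) ** 0.5 ~= 3.5317
    # "extra_key_0" and "extra_key_1" each occur in only one episode, so their
    # stats pass through unchanged.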


@@ -24,7 +24,6 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
"""

import multiprocessing
-from pathlib import Path
from unittest.mock import patch

import pytest

@@ -45,7 +44,7 @@ from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_

@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
-def test_teleoperate(tmpdir, request, robot_type, mock):
+def test_teleoperate(tmp_path, request, robot_type, mock):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    if mock and robot_type != "aloha":

@@ -53,8 +52,7 @@ def test_teleoperate(tmpdir, request, robot_type, mock):

        # Create an empty calibration directory to trigger manual calibration
        # and avoid writing calibration files in user .cache/calibration folder
-        tmpdir = Path(tmpdir)
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir
    else:

@@ -70,15 +68,14 @@ def test_teleoperate(tmpdir, request, robot_type, mock):

@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
-def test_calibrate(tmpdir, request, robot_type, mock):
+def test_calibrate(tmp_path, request, robot_type, mock):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    if mock:
        request.getfixturevalue("patch_builtins_input")

        # Create an empty calibration directory to trigger manual calibration
-        tmpdir = Path(tmpdir)
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        robot_kwargs["calibration_dir"] = calibration_dir

    robot = make_robot(**robot_kwargs)

@@ -89,7 +86,7 @@ def test_calibrate(tmpdir, request, robot_type, mock):

@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
-def test_record_without_cameras(tmpdir, request, robot_type, mock):
+def test_record_without_cameras(tmp_path, request, robot_type, mock):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    # Avoid using cameras

@@ -100,7 +97,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):

        # Create an empty calibration directory to trigger manual calibration
        # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = Path(tmpdir) / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir
    else:

@@ -108,7 +105,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
        pass

    repo_id = "lerobot/debug"
-    root = Path(tmpdir) / "data" / repo_id
+    root = tmp_path / "data" / repo_id
    single_task = "Do something."

    robot = make_robot(**robot_kwargs)

@@ -121,7 +118,6 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
        episode_time_s=1,
        reset_time_s=0.1,
        num_episodes=2,
-        run_compute_stats=False,
        push_to_hub=False,
        video=False,
        play_sounds=False,

@@ -131,8 +127,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):

@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
-def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
-    tmpdir = Path(tmpdir)
+def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    if mock and robot_type != "aloha":

@@ -140,7 +135,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):

        # Create an empty calibration directory to trigger manual calibration
        # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir
    else:

@@ -148,7 +143,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
        pass

    repo_id = "lerobot_test/debug"
-    root = tmpdir / "data" / repo_id
+    root = tmp_path / "data" / repo_id
    single_task = "Do something."

    robot = make_robot(**robot_kwargs)

@@ -180,7 +175,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
    policy_cfg = ACTConfig()
    policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)

-    out_dir = tmpdir / "logger"
+    out_dir = tmp_path / "logger"

    pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
    policy.save_pretrained(pretrained_policy_path)

@@ -207,7 +202,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
    num_image_writer_processes = 0

    eval_repo_id = "lerobot/eval_debug"
-    eval_root = tmpdir / "data" / eval_repo_id
+    eval_root = tmp_path / "data" / eval_repo_id

    rec_eval_cfg = RecordControlConfig(
        repo_id=eval_repo_id,

@@ -218,7 +213,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
        episode_time_s=1,
        reset_time_s=0.1,
        num_episodes=2,
-        run_compute_stats=False,
        push_to_hub=False,
        video=False,
        display_cameras=False,

@@ -240,7 +234,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):

@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
-def test_resume_record(tmpdir, request, robot_type, mock):
+def test_resume_record(tmp_path, request, robot_type, mock):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    if mock and robot_type != "aloha":

@@ -248,7 +242,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):

        # Create an empty calibration directory to trigger manual calibration
        # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir
    else:

@@ -258,7 +252,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):
    robot = make_robot(**robot_kwargs)

    repo_id = "lerobot/debug"
-    root = Path(tmpdir) / "data" / repo_id
+    root = tmp_path / "data" / repo_id
    single_task = "Do something."

    rec_cfg = RecordControlConfig(

@@ -272,7 +266,6 @@ def test_resume_record(tmpdir, request, robot_type, mock):
        video=False,
        display_cameras=False,
        play_sounds=False,
-        run_compute_stats=False,
        local_files_only=True,
        num_episodes=1,
    )

@@ -291,7 +284,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):

@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
-def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
+def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    if mock and robot_type != "aloha":

@@ -299,7 +292,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):

        # Create an empty calibration directory to trigger manual calibration
        # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir
    else:

@@ -316,7 +309,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
        mock_listener.return_value = (None, mock_events)

        repo_id = "lerobot/debug"
-        root = Path(tmpdir) / "data" / repo_id
+        root = tmp_path / "data" / repo_id
        single_task = "Do something."

        rec_cfg = RecordControlConfig(

@@ -331,7 +324,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
            video=False,
            display_cameras=False,
            play_sounds=False,
-            run_compute_stats=False,
        )
        dataset = record(robot, rec_cfg)

@@ -342,7 +334,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):

@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
-def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
+def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    if mock:

@@ -350,7 +342,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):

        # Create an empty calibration directory to trigger manual calibration
        # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir
    else:

@@ -367,7 +359,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
        mock_listener.return_value = (None, mock_events)

        repo_id = "lerobot/debug"
-        root = Path(tmpdir) / "data" / repo_id
+        root = tmp_path / "data" / repo_id
        single_task = "Do something."

        rec_cfg = RecordControlConfig(

@@ -382,7 +374,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
            video=False,
            display_cameras=False,
            play_sounds=False,
-            run_compute_stats=False,
        )

        dataset = record(robot, rec_cfg)

@@ -395,7 +386,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
    "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
)
@require_robot
-def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
+def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
    robot_kwargs = {"robot_type": robot_type, "mock": mock}

    if mock:

@@ -403,7 +394,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num

        # Create an empty calibration directory to trigger manual calibration
        # and avoid writing calibration files in user .cache/calibration folder
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir
    else:

@@ -420,7 +411,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
        mock_listener.return_value = (None, mock_events)

        repo_id = "lerobot/debug"
-        root = Path(tmpdir) / "data" / repo_id
+        root = tmp_path / "data" / repo_id
        single_task = "Do something."

        rec_cfg = RecordControlConfig(

@@ -436,7 +427,6 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
            video=False,
            display_cameras=False,
            play_sounds=False,
-            run_compute_stats=False,
            num_image_writer_processes=num_image_writer_processes,
        )

@@ -20,21 +20,14 @@ from copy import deepcopy
from itertools import chain
from pathlib import Path

import einops
import numpy as np
import pytest
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file

import lerobot
from lerobot.common.datasets.compute_stats import (
    aggregate_stats,
    compute_stats,
    get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.image_writer import image_array_to_pil_image
from lerobot.common.datasets.lerobot_dataset import (

@@ -44,13 +37,11 @@ from lerobot.common.datasets.lerobot_dataset import (
from lerobot.common.datasets.utils import (
    create_branch,
    flatten_dict,
    hf_transform_to_torch,
    unflatten_dict,
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID

@@ -196,12 +187,12 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": torch.randn(1), "task": "dummy"})
+    dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert len(dataset) == 1
-    assert dataset[0]["task"] == "dummy"
+    assert dataset[0]["task"] == "Dummy task"
    assert dataset[0]["task_index"] == 0
    assert dataset[0]["state"].ndim == 0

@@ -209,9 +200,9 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": torch.randn(2), "task": "dummy"})
+    dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["state"].shape == torch.Size([2])

@@ -219,9 +210,9 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": torch.randn(2, 4), "task": "dummy"})
+    dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["state"].shape == torch.Size([2, 4])

@@ -229,9 +220,9 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "dummy"})
+    dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["state"].shape == torch.Size([2, 4, 3])

@@ -239,9 +230,9 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "dummy"})
+    dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])

@@ -249,9 +240,9 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "dummy"})
+    dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])

@@ -261,7 +252,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["state"].ndim == 0

@@ -271,7 +262,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["caption"] == "Dummy caption"

@@ -307,7 +298,7 @@ def test_add_frame_image(image_dataset):
    dataset = image_dataset
    dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)

@@ -316,7 +307,7 @@ def test_add_frame_image_h_w_c(image_dataset):
    dataset = image_dataset
    dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)

@@ -326,7 +317,7 @@ def test_add_frame_image_uint8(image_dataset):
    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
    dataset.add_frame({"image": image, "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)

@@ -336,7 +327,7 @@ def test_add_frame_image_pil(image_dataset):
    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
    dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
    dataset.save_episode(encode_videos=False)
-    dataset.consolidate(run_compute_stats=False)
+    dataset.consolidate()

    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)

@@ -463,67 +454,6 @@ def test_multidataset_frames():
    assert torch.equal(sub_dataset_item[k], dataset_item[k])


# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
def test_compute_stats_on_xarm():
    """Check that the statistics are computed correctly according to the stats_patterns property.

    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
    because we are working with a small dataset).
    """
    # TODO(rcadene, aliberts): remove dataset download
    dataset = LeRobotDataset("lerobot/xarm_lift_medium", episodes=[0])

    # reduce size of dataset sample on which stats compute is tested to 10 frames
    dataset.hf_dataset = dataset.hf_dataset.select(range(10))

    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
    # dataset into even batches.
    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)

    # get einops patterns to aggregate batches and compute statistics
    stats_patterns = get_stats_einops_patterns(dataset)

    # get all frames from the dataset in the same dtype and range as during compute_stats
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=len(dataset),
        shuffle=False,
    )
    full_batch = next(iter(dataloader))

    # compute stats based on all frames from the dataset without any batching
    expected_stats = {}
    for k, pattern in stats_patterns.items():
        full_batch[k] = full_batch[k].float()
        expected_stats[k] = {}
        expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
        expected_stats[k]["std"] = torch.sqrt(
            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
        )
        expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
        expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")

    # test computed stats match expected stats
    for k in stats_patterns:
        assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
        assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
        assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
        assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])

    # load stats used during training which are expected to match the ones returned by computed_stats
    loaded_stats = dataset.meta.stats  # noqa: F841

    # TODO(rcadene): we can't test this because expected_stats is computed on a subset
    # # test loaded stats match expected stats
    # for k in stats_patterns:
    #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
    #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
    #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
    #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])


# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
    d = {

@@ -627,35 +557,6 @@ def test_backward_compatibility(repo_id):
    # load_and_compare(i - 1)


@pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_aggregate_stats():
    """Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
    with seeded_context(0):
        data_a = torch.rand(30, dtype=torch.float32)
        data_b = torch.rand(20, dtype=torch.float32)
        data_c = torch.rand(20, dtype=torch.float32)

    hf_dataset_1 = Dataset.from_dict(
        {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
    )
    hf_dataset_1.set_transform(hf_transform_to_torch)
    hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
    hf_dataset_2.set_transform(hf_transform_to_torch)
    hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
    hf_dataset_3.set_transform(hf_transform_to_torch)
    dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
    dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
    dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
    dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
    dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
    dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
    stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
    for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
        for agg_fn in ["mean", "min", "max"]:
            assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
        assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
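        # correction=0 requests the population std (dividing by n rather than
        # n - 1), which matches the mean-of-squared-deviations reduction that
        # compute_stats uses (see the removed test_compute_stats_on_xarm above).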


@pytest.mark.skip("Requires internet access")
def test_create_branch():
    api = HfApi()

@@ -1,370 +0,0 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
we skip them for now in our CI.

Example to run backward compatibility tests locally:
```
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""

from pathlib import Path

import numpy as np
import pytest
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg


def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/img",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
    zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
    zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
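    # Note: episode_ends holds cumulative frame counts, so [1, 3, 4] encodes
    # three episodes covering frames [0:1], [1:3] and [3:4] of the 4 mocked frames.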

    store.close()


def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
    import zarr

    raw_dir.mkdir(parents=True, exist_ok=True)
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    store = zarr.DirectoryStore(zarr_path)
    zarr_data = zarr.group(store=store)

    zarr_data.create_dataset(
        "data/camera0_rgb",
        shape=(num_frames, 96, 96, 3),
        chunks=(num_frames, 96, 96, 3),
        dtype=np.uint8,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_end_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_demo_start_pose",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
    )
    zarr_data.create_dataset(
        "data/robot0_eef_rot_axis_angle",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "data/robot0_gripper_width",
        shape=(num_frames, 5),
        chunks=(num_frames, 5),
        dtype=np.float32,
        overwrite=True,
    )
    zarr_data.create_dataset(
        "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
    )

    zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
    zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
    zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
    zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])

    store.close()


def _mock_download_raw_xarm(raw_dir, num_frames=4):
    import pickle

    dataset_dict = {
        "observations": {
            "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
            "state": np.random.randn(num_frames, 4),
        },
        "actions": np.random.randn(num_frames, 3),
        "rewards": np.random.randn(num_frames),
        "masks": np.random.randn(num_frames),
        "dones": np.array([False, True, True, True]),
    }

    raw_dir.mkdir(parents=True, exist_ok=True)
    pkl_path = raw_dir / "buffer.pkl"
    with open(pkl_path, "wb") as f:
        pickle.dump(dataset_dict, f)


def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
    import h5py

    for ep_idx in range(num_episodes):
        raw_dir.mkdir(parents=True, exist_ok=True)
        path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
        with h5py.File(str(path_h5), "w") as f:
            f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
            f.create_dataset(
                "observations/images/top",
                data=np.random.randint(
                    0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
                ),
            )


def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
    from datetime import datetime, timedelta, timezone

    import pandas

    def write_parquet(key, timestamps, values):
        data = {
            "timestamp_utc": timestamps,
            key: values,
        }
        df = pandas.DataFrame(data)
        raw_dir.mkdir(parents=True, exist_ok=True)
        df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")

    episode_indices = [None, None, -1, None, None, -1, None, None, -1]
    episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]

    cam_key = "observation.images.cam_high"
    timestamps = []
    actions = []
    states = []
    frames = []
    # `+ num_episodes` for buffer frames associated to episode_index=-1
    for i, frame_idx in enumerate(frame_indices):
        t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
        action = np.random.randn(21).tolist()
        state = np.random.randn(21).tolist()
        ep_idx = episode_indices_mapping[i]
        frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
        timestamps.append(t_utc)
        actions.append(action)
        states.append(state)
        frames.append(frame)

    write_parquet(cam_key, timestamps, frames)
    write_parquet("observation.state", timestamps, states)
    write_parquet("action", timestamps, actions)
    write_parquet("episode_index", timestamps, episode_indices)

    # write fake mp4 file for each episode
    for ep_idx in range(num_episodes):
        imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)

        tmp_imgs_dir = raw_dir / "tmp_images"
        save_images_concurrently(imgs_array, tmp_imgs_dir)

        fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
        video_path = raw_dir / "videos" / fname
        encode_video_frames(tmp_imgs_dir, video_path, fps, vcodec="libx264")


def _mock_download_raw(raw_dir, repo_id):
    if "wrist_gripper" in repo_id:
        _mock_download_raw_dora(raw_dir)
    elif "aloha" in repo_id:
        _mock_download_raw_aloha(raw_dir)
    elif "pusht" in repo_id:
        _mock_download_raw_pusht(raw_dir)
    elif "xarm" in repo_id:
        _mock_download_raw_xarm(raw_dir)
    elif "umi" in repo_id:
        _mock_download_raw_umi(raw_dir)
    else:
        raise ValueError(repo_id)


@pytest.mark.skip("push_dataset_to_hub is deprecated")
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
    with pytest.raises(ValueError):
        push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")


@pytest.mark.skip("push_dataset_to_hub is deprecated")
def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
    tmpdir = Path(tmpdir)
    out_dir = tmpdir / "out"
    raw_dir = tmpdir / "raw"
    # mkdir to skip download
    raw_dir.mkdir(parents=True, exist_ok=True)
    with pytest.raises(ValueError):
        push_dataset_to_hub(
            raw_dir=raw_dir,
            raw_format="some_format",
            repo_id="user/dataset",
            local_dir=out_dir,
            force_override=False,
        )


@pytest.mark.skip("push_dataset_to_hub is deprecated")
@pytest.mark.parametrize(
    "required_packages, raw_format, repo_id, make_test_data",
    [
        (["gym_pusht"], "pusht_zarr", "lerobot/pusht", False),
        (["gym_pusht"], "pusht_zarr", "lerobot/pusht", True),
        (None, "xarm_pkl", "lerobot/xarm_lift_medium", False),
        (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False),
        (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False),
        (None, "dora_parquet", "cadene/wrist_gripper", False),
    ],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
    num_episodes = 3
    tmpdir = Path(tmpdir)

    raw_dir = tmpdir / f"{repo_id}_raw"
    _mock_download_raw(raw_dir, repo_id)

    local_dir = tmpdir / repo_id

    lerobot_dataset = push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
        tests_data_dir=tmpdir / "tests/data" if make_test_data else None,
        encoding={"vcodec": "libx264"},
    )

    # minimal generic tests on the local directory containing LeRobotDataset
    assert (local_dir / "meta_data" / "info.json").exists()
    assert (local_dir / "meta_data" / "stats.safetensors").exists()
    assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
    for i in range(num_episodes):
        for cam_key in lerobot_dataset.camera_keys:
            assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
    assert (local_dir / "train" / "dataset_info.json").exists()
    assert (local_dir / "train" / "state.json").exists()
    assert len(list((local_dir / "train").glob("*.arrow"))) > 0

    # minimal generic tests on the item
    item = lerobot_dataset[0]
    assert "index" in item
    assert "episode_index" in item
    assert "timestamp" in item
    for cam_key in lerobot_dataset.camera_keys:
        assert cam_key in item

    if make_test_data:
        # Check that only the first episode is selected.
        test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data")
        num_frames = sum(
            i == lerobot_dataset.hf_dataset["episode_index"][0]
            for i in lerobot_dataset.hf_dataset["episode_index"]
        ).item()
        assert (
            test_dataset.hf_dataset["episode_index"]
            == lerobot_dataset.hf_dataset["episode_index"][:num_frames]
        )
        for k in ["from", "to"]:
            assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])


@pytest.mark.skip("push_dataset_to_hub is deprecated")
@pytest.mark.parametrize(
    "raw_format, repo_id",
    [
        # TODO(rcadene): add raw dataset test artifacts
        ("pusht_zarr", "lerobot/pusht"),
        ("xarm_pkl", "lerobot/xarm_lift_medium"),
        ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
        ("umi_zarr", "lerobot/umi_cup_in_the_wild"),
        ("dora_parquet", "cadene/wrist_gripper"),
    ],
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
    _, dataset_id = repo_id.split("/")

    tmpdir = Path(tmpdir)
    raw_dir = tmpdir / f"{dataset_id}_raw"
    local_dir = tmpdir / repo_id

    push_dataset_to_hub(
        raw_dir=raw_dir,
        raw_format=raw_format,
        repo_id=repo_id,
        push_to_hub=False,
        local_dir=local_dir,
        force_override=False,
        cache_dir=tmpdir / "cache",
        episodes=[0],
    )

    ds_actual = LeRobotDataset(repo_id, root=tmpdir)
    ds_reference = LeRobotDataset(repo_id)

    assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)

    def check_same_items(item1, item2):
        assert item1.keys() == item2.keys(), "Keys mismatch"

        for key in item1:
            if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
                assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
            else:
                assert item1[key] == item2[key], f"Mismatch found in key: {key}"

    for i in range(len(ds_reference.hf_dataset)):
        item_reference = ds_reference.hf_dataset[i]
        item_actual = ds_actual.hf_dataset[i]
        check_same_items(item_reference, item_actual)


@@ -23,8 +23,6 @@ pytest -sx 'tests/test_robots.py::test_robot[aloha-True]'
```
"""

-from pathlib import Path
-
import pytest
import torch

@@ -35,7 +33,7 @@ from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot

@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
-def test_robot(tmpdir, request, robot_type, mock):
+def test_robot(tmp_path, request, robot_type, mock):
    # TODO(rcadene): measure fps in nightly?
    # TODO(rcadene): test logs
    # TODO(rcadene): add compatibility with other robots

@@ -50,8 +48,7 @@ def test_robot(tmpdir, request, robot_type, mock):
        request.getfixturevalue("patch_builtins_input")

        # Create an empty calibration directory to trigger manual calibration
-        tmpdir = Path(tmpdir)
-        calibration_dir = tmpdir / robot_type
+        calibration_dir = tmp_path / robot_type
        mock_calibration_dir(calibration_dir)
        robot_kwargs["calibration_dir"] = calibration_dir