diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py
index e9015d2c..2eaf1c1c 100644
--- a/examples/port_datasets/pusht_zarr.py
+++ b/examples/port_datasets/pusht_zarr.py
@@ -148,6 +148,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
     action = zarr_data["action"][:]
     image = zarr_data["img"]  # (b, h, w, c)

+    if image.dtype == np.float32 and image.max() == np.float32(255):
+        # HACK: images are loaded as float32 but they actually encode uint8 data
+        image = image.astype(np.uint8)
+
     episode_data_index = {
         "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
         "to": zarr_data.meta["episode_ends"],
diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py
index c6211699..7519c743 100644
--- a/lerobot/common/datasets/compute_stats.py
+++ b/lerobot/common/datasets/compute_stats.py
@@ -13,202 +13,148 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import deepcopy
-from math import ceil
+import numpy as np

-import einops
-import torch
-import tqdm
+from lerobot.common.datasets.utils import load_image_as_numpy

-def get_stats_einops_patterns(dataset, num_workers=0):
-    """These einops patterns will be used to aggregate batches and compute statistics.
+def estimate_num_samples(
+    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
+) -> int:
+    """Heuristic to estimate the number of samples based on dataset size.
+    The power controls the sample growth relative to dataset size.
+    Lower the power to reduce the number of samples.

-    Note: We assume the images are in channel first format
+    For default arguments, we have:
+    - from 1 to ~500, num_samples=100
+    - at 1000, num_samples=177
+    - at 2000, num_samples=299
+    - at 5000, num_samples=594
+    - at 10000, num_samples=1000
+    - at 20000, num_samples=1681
     """
+    if dataset_len < min_num_samples:
+        min_num_samples = dataset_len
+    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))

-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=num_workers,
-        batch_size=2,
-        shuffle=False,
-    )
-    batch = next(iter(dataloader))

-    stats_patterns = {}
+def sample_indices(data_len: int) -> list[int]:
+    num_samples = estimate_num_samples(data_len)
+    return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()

-    for key in dataset.features:
-        # sanity check that tensors are not float64
-        assert batch[key].dtype != torch.float64

-        # if isinstance(feats_type, (VideoFrame, Image)):
-        if key in dataset.meta.camera_keys:
-            # sanity check that images are channel first
-            _, c, h, w = batch[key].shape
-            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+def sample_images(image_paths: list[str]) -> np.ndarray:
+    sampled_indices = sample_indices(len(image_paths))

+    images = []
+    for idx in sampled_indices:
+        path = image_paths[idx]
+        # we load as uint8 to reduce memory usage
+        img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
+        images.append(img)

-            # sanity check that images are float32 in range [0,1]
-            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
-            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
-            assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
+    images = np.stack(images)
+    return images

-            stats_patterns[key] = "b c h w -> c 1 1"
-        elif batch[key].ndim == 2:
-            stats_patterns[key] = "b c -> c "
-        elif batch[key].ndim == 1:
-            stats_patterns[key] = "b -> 1"
+
+def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
+    return {
+        "min": np.min(array, axis=axis, keepdims=keepdims),
+        "max": np.max(array, axis=axis, keepdims=keepdims),
+        "mean": np.mean(array, axis=axis, keepdims=keepdims),
+        "std": np.std(array, axis=axis, keepdims=keepdims),
+        "count": np.array([len(array)]),
+    }
+
+
+def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
+    ep_stats = {}
+    for key, data in episode_data.items():
+        if features[key]["dtype"] == "string":
+            continue  # HACK: we should receive np.arrays of strings
+        elif features[key]["dtype"] in ["image", "video"]:
+            ep_ft_array = sample_images(data)  # data is a list of image paths
+            axes_to_reduce = (0, 2, 3)  # keep channel dim
+            keepdims = True
         else:
-            raise ValueError(f"{key}, {batch[key].shape}")
+            ep_ft_array = data  # data is already a np.ndarray
+            axes_to_reduce = 0  # compute stats over the first axis
+            keepdims = data.ndim == 1  # keep as np.array

-    return stats_patterns
+        ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
+
+        # finally, we normalize and remove batch dim for images
+        if features[key]["dtype"] in ["image", "video"]:
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    return ep_stats

-def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
-    """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
-    if max_num_samples is None:
-        max_num_samples = len(dataset)
-
-    # for more info on why we need to set the same number of workers, see `load_from_videos`
-    stats_patterns = get_stats_einops_patterns(dataset, num_workers)
-
-    # mean and std will be computed incrementally while max and min will track the running value.
-    mean, std, max, min = {}, {}, {}, {}
-    for key in stats_patterns:
-        mean[key] = torch.tensor(0.0).float()
-        std[key] = torch.tensor(0.0).float()
-        max[key] = torch.tensor(-float("inf")).float()
-        min[key] = torch.tensor(float("inf")).float()
-
-    def create_seeded_dataloader(dataset, batch_size, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=num_workers,
-            batch_size=batch_size,
-            shuffle=True,
-            drop_last=False,
-            generator=generator,
-        )
-        return dataloader
-
-    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
-    # surprises when rerunning the sampler.
-    first_batch = None
-    running_item_count = 0  # for online mean computation
-    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        if first_batch is None:
-            first_batch = deepcopy(batch)
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation.
- batch_mean = einops.reduce(batch[key], pattern, "mean") - # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents - # the update step, N is the running item count, B is this batch size, x̄ is the running mean, - # and x is the current batch mean. Some rearrangement is then required to avoid risking - # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields - # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ - mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count - max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) - min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) - - if i == ceil(max_num_samples / batch_size) - 1: - break - - first_batch_ = None - running_item_count = 0 # for online std computation - dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) - for i, batch in enumerate( - tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") - ): - this_batch_size = len(batch["index"]) - running_item_count += this_batch_size - # Sanity check to make sure the batches are still in the same order as before. - if first_batch_ is None: - first_batch_ = deepcopy(batch) - for key in stats_patterns: - assert torch.equal(first_batch_[key], first_batch[key]) - for key, pattern in stats_patterns.items(): - batch[key] = batch[key].float() - # Numerically stable update step for mean computation (where the mean is over squared - # residuals).See notes in the mean computation loop above. - batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") - std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count - - if i == ceil(max_num_samples / batch_size) - 1: - break - - for key in stats_patterns: - std[key] = torch.sqrt(std[key]) - - stats = {} - for key in stats_patterns: - stats[key] = { - "mean": mean[key], - "std": std[key], - "max": max[key], - "min": min[key], - } - return stats +def _assert_type_and_shape(stats_list: list[dict[str, dict]]): + for i in range(len(stats_list)): + for fkey in stats_list[i]: + for k, v in stats_list[i][fkey].items(): + if not isinstance(v, np.ndarray): + raise ValueError( + f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead." + ) + if v.ndim == 0: + raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") + if k == "count" and v.shape != (1,): + raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") + if "image" in fkey and k != "count" and v.shape != (3, 1, 1): + raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") -def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]: - """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch. +def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: + """Aggregates stats for a single feature.""" + means = np.stack([s["mean"] for s in stats_ft_list]) + variances = np.stack([s["std"] ** 2 for s in stats_ft_list]) + counts = np.stack([s["count"] for s in stats_ft_list]) + total_count = counts.sum(axis=0) - The final stats will have the union of all data keys from each of the datasets. + # Prepare weighted mean by matching number of dimensions + while counts.ndim < means.ndim: + counts = np.expand_dims(counts, axis=-1) - The final stats will have the union of all data keys from each of the datasets. 
For instance: - - new_max = max(max_dataset_0, max_dataset_1, ...) + # Compute the weighted mean + weighted_means = means * counts + total_mean = weighted_means.sum(axis=0) / total_count + + # Compute the variance using the parallel algorithm + delta_means = means - total_mean + weighted_variances = (variances + delta_means**2) * counts + total_variance = weighted_variances.sum(axis=0) / total_count + + return { + "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0), + "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0), + "mean": total_mean, + "std": np.sqrt(total_variance), + "count": total_count, + } + + +def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: + """Aggregate stats from multiple compute_stats outputs into a single set of stats. + + The final stats will have the union of all data keys from each of the stats dicts. + + For instance: - new_min = min(min_dataset_0, min_dataset_1, ...) - - new_mean = (mean of all data) + - new_max = max(max_dataset_0, max_dataset_1, ...) + - new_mean = (mean of all data, weighted by counts) - new_std = (std of all data) """ - data_keys = set() - for dataset in ls_datasets: - data_keys.update(dataset.meta.stats.keys()) - stats = {k: {} for k in data_keys} - for data_key in data_keys: - for stat_key in ["min", "max"]: - # compute `max(dataset_0["max"], dataset_1["max"], ...)` - stats[data_key][stat_key] = einops.reduce( - torch.stack( - [ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats], - dim=0, - ), - "n ... -> ...", - stat_key, - ) - total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats) - # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective - # dataset, then divide by total_samples to get the overall "mean". - # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of - # numerical overflow! - stats[data_key]["mean"] = sum( - d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples) - for d in ls_datasets - if data_key in d.meta.stats - ) - # The derivation for standard deviation is a little more involved but is much in the same spirit as - # the computation of the mean. - # Given two sets of data where the statistics are known: - # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ] - # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined - # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of - # numerical overflow! 
- stats[data_key]["std"] = torch.sqrt( - sum( - ( - d.meta.stats[data_key]["std"] ** 2 - + (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2 - ) - * (d.num_frames / total_samples) - for d in ls_datasets - if data_key in d.meta.stats - ) - ) - return stats + + _assert_type_and_shape(stats_list) + + data_keys = {key for stats in stats_list for key in stats} + aggregated_stats = {key: {} for key in data_keys} + + for key in data_keys: + stats_with_key = [stats[key] for stats in stats_list if key in stats] + aggregated_stats[key] = aggregate_feature_stats(stats_with_key) + + return aggregated_stats diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 5c4ae68e..c7f0b2b3 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -26,18 +26,17 @@ import PIL.Image import torch import torch.utils from datasets import load_dataset -from huggingface_hub import create_repo, snapshot_download, upload_folder +from huggingface_hub import HfApi, snapshot_download -from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats +from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, - EPISODES_PATH, INFO_PATH, - STATS_PATH, TASKS_PATH, append_jsonlines, + backward_compatible_episodes_stats, check_delta_timestamps, check_frame_features, check_timestamps_sync, @@ -52,10 +51,13 @@ from lerobot.common.datasets.utils import ( get_hub_safe_version, hf_transform_to_torch, load_episodes, + load_episodes_stats, load_info, load_stats, load_tasks, - serialize_dict, + write_episode, + write_episode_stats, + write_info, write_json, write_parquet, ) @@ -90,6 +92,17 @@ class LeRobotDatasetMetadata: self.stats = load_stats(self.root) self.tasks, self.task_to_task_index = load_tasks(self.root) self.episodes = load_episodes(self.root) + try: + self.episodes_stats = load_episodes_stats(self.root) + self.stats = aggregate_stats(list(self.episodes_stats.values())) + except FileNotFoundError: + logging.warning( + f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead. 
+ Convert your dataset stats to the new format using this command: + python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """ + ) + self.stats = load_stats(self.root) + self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) def pull_from_repo( self, @@ -228,7 +241,13 @@ class LeRobotDatasetMetadata: } append_jsonlines(task_dict, self.root / TASKS_PATH) - def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None: + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + ) -> None: self.info["total_episodes"] += 1 self.info["total_frames"] += episode_length @@ -238,21 +257,19 @@ class LeRobotDatasetMetadata: self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["total_videos"] += len(self.video_keys) - write_json(self.info, self.root / INFO_PATH) + write_info(self.info, self.root) episode_dict = { "episode_index": episode_index, "tasks": episode_tasks, "length": episode_length, } - self.episodes.append(episode_dict) - append_jsonlines(episode_dict, self.root / EPISODES_PATH) + self.episodes[episode_index] = episode_dict + write_episode(episode_dict, self.root) - # TODO(aliberts): refactor stats in save_episodes - # image_sampling = int(self.fps / 2) # sample 2 img/s for the stats - # ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling) - # ep_stats = serialize_dict(ep_stats) - # append_jsonlines(ep_stats, self.root / STATS_PATH) + self.episodes_stats[episode_index] = episode_stats + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats + write_episode_stats(episode_index, episode_stats, self.root) def write_video_info(self) -> None: """ @@ -309,6 +326,7 @@ class LeRobotDatasetMetadata: ) else: # TODO(aliberts, rcadene): implement sanity check for features + features = {**features, **DEFAULT_FEATURES} # check if none of the features contains a "/" in their names, # as this would break the dict flattening in the stats computation, which uses '/' as separator @@ -319,7 +337,7 @@ class LeRobotDatasetMetadata: features = {**features, **DEFAULT_FEATURES} obj.tasks, obj.task_to_task_index = {}, {} - obj.stats, obj.episodes = {}, [] + obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() @@ -457,6 +475,9 @@ class LeRobotDataset(torch.utils.data.Dataset): # Load metadata self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only) + if self.episodes is not None and self.meta._version == CODEBASE_VERSION: + episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] + self.stats = aggregate_stats(episodes_stats) # Check version check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) @@ -479,10 +500,13 @@ class LeRobotDataset(torch.utils.data.Dataset): def push_to_hub( self, + branch: str | None = None, + create_card: bool = True, tags: list | None = None, license: str | None = "apache-2.0", push_videos: bool = True, private: bool = False, + allow_patterns: list[str] | str | None = None, **card_kwargs, ) -> None: if not self.consolidated: @@ -496,24 +520,32 @@ class LeRobotDataset(torch.utils.data.Dataset): if not push_videos: ignore_patterns.append("videos/") - create_repo( + 
hub_api = HfApi() + hub_api.create_repo( repo_id=self.repo_id, private=private, repo_type="dataset", exist_ok=True, ) + if branch: + create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset") - upload_folder( + hub_api.upload_folder( repo_id=self.repo_id, folder_path=self.root, repo_type="dataset", + revision=branch, + allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) - card = create_lerobot_dataset_card( - tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs - ) - card.push_to_hub(repo_id=self.repo_id, repo_type="dataset") - create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset") + if create_card: + card = create_lerobot_dataset_card( + tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + ) + card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) + + if not branch: + create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset") def pull_from_repo( self, @@ -630,7 +662,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if key not in self.meta.video_keys } - def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. This probably happens because a memory reference to the video loader is created in @@ -660,8 +692,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices = None if self.delta_indices is not None: - current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx - query_indices, padding = self._get_query_indices(idx, current_ep_idx) + query_indices, padding = self._get_query_indices(idx, ep_idx) query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} for key, val in query_result.items(): @@ -735,11 +766,13 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.episode_buffer is None: self.episode_buffer = self.create_episode_buffer() + # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) + # Add frame features to episode_buffer for key in frame: if key == "task": # Note: we associate the task in natural language to its task index during `save_episode` @@ -787,7 +820,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(aliberts): Add option to use existing episode_index raise NotImplementedError( "You might have manually provided the episode_buffer with an episode_index that doesn't " - "match the total number of episodes in the dataset. This is not supported for now." + "match the total number of episodes already in the dataset. This is not supported for now." 
) if episode_length == 0: @@ -821,8 +854,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) - - self.meta.save_episode(episode_index, episode_length, episode_tasks) + ep_stats = compute_episode_stats(episode_buffer, self.features) + self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) if encode_videos and len(self.meta.video_keys) > 0: video_paths = self.encode_episode_videos(episode_index) @@ -908,7 +941,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return video_paths - def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: + def consolidate(self, keep_image_files: bool = False) -> None: self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) @@ -928,17 +961,7 @@ class LeRobotDataset(torch.utils.data.Dataset): parquet_files = list(self.root.rglob("*.parquet")) assert len(parquet_files) == self.num_episodes - if run_compute_stats: - self.stop_image_writer() - # TODO(aliberts): refactor stats in save_episodes - self.meta.stats = compute_stats(self) - serialized_stats = serialize_dict(self.meta.stats) - write_json(serialized_stats, self.root / STATS_PATH) - self.consolidated = True - else: - logging.warning( - "Skipping computation of the dataset statistics, dataset is not fully consolidated." - ) + self.consolidated = True @classmethod def create( @@ -1056,7 +1079,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps - self.stats = aggregate_stats(self._datasets) + # TODO(rcadene, aliberts): We should not perform this aggregation for datasets + # with multiple robots of different ranges. Instead we should have one normalization + # per robot. 
+        self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])

     @property
     def repo_id_to_index(self):
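A note on the per-episode aggregation used above: `aggregate_stats` combines episode-level statistics with count-weighted means and the parallel variance algorithm, so loading a subset of episodes yields the same stats as recomputing over the concatenated frames. A minimal sketch with illustrative numbers (the feature key and values are made up):

```python
import numpy as np

from lerobot.common.datasets.compute_stats import aggregate_stats

# Two toy per-episode stats dicts for a 1-dimensional feature. Shapes follow the
# convention checked by _assert_type_and_shape: arrays of ndim >= 1, count of shape (1,).
ep0 = {"observation.state": {"min": np.array([1.0]), "max": np.array([3.0]),
                             "mean": np.array([2.0]), "std": np.array([1.0]), "count": np.array([4])}}
ep1 = {"observation.state": {"min": np.array([0.0]), "max": np.array([5.0]),
                             "mean": np.array([4.0]), "std": np.array([2.0]), "count": np.array([12])}}

agg = aggregate_stats([ep0, ep1])

# Count-weighted mean: (4 * 2.0 + 12 * 4.0) / 16 = 3.5
assert np.isclose(agg["observation.state"]["mean"], 3.5)
# Parallel variance: ((1.0 + 1.5**2) * 4 + (4.0 + 0.5**2) * 12) / 16 = 4.0, so std = 2.0
assert np.isclose(agg["observation.state"]["std"], 2.0)
assert agg["observation.state"]["count"] == np.array([16])
```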
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 505a5492..8b734042 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -43,6 +43,7 @@ DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
 INFO_PATH = "meta/info.json"
 EPISODES_PATH = "meta/episodes.jsonl"
 STATS_PATH = "meta/stats.json"
+EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
 TASKS_PATH = "meta/tasks.jsonl"

 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
@@ -113,7 +114,16 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:

 def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
-    serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+    serialized_dict = {}
+    for key, value in flatten_dict(stats).items():
+        if isinstance(value, (torch.Tensor, np.ndarray)):
+            serialized_dict[key] = value.tolist()
+        elif isinstance(value, np.generic):
+            serialized_dict[key] = value.item()
+        elif isinstance(value, (int, float)):
+            serialized_dict[key] = value
+        else:
+            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
     return unflatten_dict(serialized_dict)

@@ -154,6 +164,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
         writer.write(data)


+def write_info(info: dict, local_dir: Path):
+    write_json(info, local_dir / INFO_PATH)
+
+
 def load_info(local_dir: Path) -> dict:
     info = load_json(local_dir / INFO_PATH)
     for ft in info["features"].values():
@@ -161,12 +175,29 @@ def load_info(local_dir: Path) -> dict:
     return info


-def load_stats(local_dir: Path) -> dict:
+def write_stats(stats: dict, local_dir: Path):
+    serialized_stats = serialize_dict(stats)
+    write_json(serialized_stats, local_dir / STATS_PATH)
+
+
+def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
+    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)
+
+
+def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
     if not (local_dir / STATS_PATH).exists():
         return None
     stats = load_json(local_dir / STATS_PATH)
-    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
-    return unflatten_dict(stats)
+    return cast_stats_to_numpy(stats)
+
+
+def write_task(task_index: int, task: dict, local_dir: Path):
+    task_dict = {
+        "task_index": task_index,
+        "task": task,
+    }
+    append_jsonlines(task_dict, local_dir / TASKS_PATH)


 def load_tasks(local_dir: Path) -> dict:
@@ -176,16 +207,42 @@
+def write_episode(episode: dict, local_dir: Path):
+    append_jsonlines(episode, local_dir / EPISODES_PATH)
+
+
 def load_episodes(local_dir: Path) -> dict:
-    return load_jsonlines(local_dir / EPISODES_PATH)
+    episodes = load_jsonlines(local_dir / EPISODES_PATH)
+    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


-def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
+def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
+    # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
+    # is a dictionary of stats and not an integer.
+    episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
+    append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
+
+
+def load_episodes_stats(local_dir: Path) -> dict:
+    episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
+    return {
+        item["episode_index"]: cast_stats_to_numpy(item["stats"])
+        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
+    }
+
+
+def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
+    return {ep_idx: stats for ep_idx in episodes}
+
+
+def load_image_as_numpy(
+    fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
+) -> np.ndarray:
     img = PILImage.open(fpath).convert("RGB")
     img_array = np.array(img, dtype=dtype)
     if channel_first:  # (H, W, C) -> (C, H, W)
         img_array = np.transpose(img_array, (2, 0, 1))
-    if "float" in dtype:
+    if np.issubdtype(dtype, np.floating):
         img_array /= 255.0
     return img_array

@@ -370,9 +427,9 @@ def create_empty_dataset_info(


 def get_episode_data_index(
-    episode_dicts: list[dict], episodes: list[int] | None = None
+    episode_dicts: dict[int, dict], episodes: list[int] | None = None
 ) -> dict[str, torch.Tensor]:
-    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
+    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
     if episodes is not None:
         episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
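The per-episode stats above round-trip through `episodes_stats.jsonl` via `serialize_dict` and `cast_stats_to_numpy`. A small sketch of that round-trip (toy values; the feature key is illustrative):

```python
import numpy as np

from lerobot.common.datasets.utils import cast_stats_to_numpy, serialize_dict

stats = {"action": {"mean": np.array([0.1, -0.2]), "count": np.array([50])}}

# serialize_dict flattens nested keys (using '/' as separator) and converts
# numpy arrays to plain lists, which is what gets appended to the jsonl file.
serialized = serialize_dict(stats)
assert serialized == {"action": {"mean": [0.1, -0.2], "count": [50]}}

# cast_stats_to_numpy performs the inverse when stats are loaded back from disk.
restored = cast_stats_to_numpy(serialized)
assert np.array_equal(restored["action"]["mean"], stats["action"]["mean"])
```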
diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
new file mode 100644
index 00000000..0c5d2688
--- /dev/null
+++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
@@ -0,0 +1,87 @@
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
+2.1. It performs the following:
+
+- Generates per-episode stats and writes them in `episodes_stats.jsonl`
+- Removes the deprecated `stats.json` (if `--delete-old-stats` is passed)
+- Updates codebase_version in `info.json`
+
+Usage:
+
+```bash
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
+    --repo-id=aliberts/koch_tutorial
+```
+
+"""
+# TODO(rcadene, aliberts): ensure this script works for any other changes for the final v2.1
+
+import argparse
+
+from huggingface_hub import HfApi
+
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
+from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+
+
+def main(
+    repo_id: str,
+    test_branch: str | None = None,
+    delete_old_stats: bool = False,
+    num_workers: int = 4,
+):
+    dataset = LeRobotDataset(repo_id)
+    if (dataset.root / EPISODES_STATS_PATH).is_file():
+        raise FileExistsError("episodes_stats.jsonl already exists.")
+
+    convert_stats(dataset, num_workers=num_workers)
+    ref_stats = load_stats(dataset.root)
+    check_aggregate_stats(dataset, ref_stats)
+
+    dataset.meta.info["codebase_version"] = CODEBASE_VERSION
+    write_info(dataset.meta.info, dataset.root)
+
+    dataset.push_to_hub(branch=test_branch, create_card=False, allow_patterns="meta/")
+
+    if delete_old_stats:
+        if (dataset.root / STATS_PATH).is_file():
+            (dataset.root / STATS_PATH).unlink()
+        hub_api = HfApi()
+        if hub_api.file_exists(
+            STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
+        ):
+            hub_api.delete_file(
+                STATS_PATH, repo_id=dataset.repo_id, revision=test_branch, repo_type="dataset"
+            )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+    )
+    parser.add_argument(
+        "--test-branch",
+        type=str,
+        default=None,
+        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
+    )
+    parser.add_argument(
+        "--delete-old-stats",
+        action="store_true",
+        help="Delete the deprecated `stats.json`",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for parallelizing compute",
+    )
+
+    args = parser.parse_args()
+    main(**vars(args))
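A side note on the `--delete-old-stats` flag: it uses `action="store_true"` rather than `type=bool`, because argparse passes the raw command-line string to the `type` callable, and any non-empty string (including "False") is truthy. A quick demonstration of the pitfall:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--delete-old-stats", type=bool, default=False)
# The string "False" is non-empty, so bool("False") evaluates to True.
assert parser.parse_args(["--delete-old-stats", "False"]).delete_old_stats is True

parser = argparse.ArgumentParser()
parser.add_argument("--delete-old-stats", action="store_true")
assert parser.parse_args([]).delete_old_stats is False
assert parser.parse_args(["--delete-old-stats"]).delete_old_stats is True
```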
diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py
new file mode 100644
index 00000000..b13e0e19
--- /dev/null
+++ b/lerobot/common/datasets/v21/convert_stats.py
@@ -0,0 +1,85 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import numpy as np
+from tqdm import tqdm
+
+from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import write_episode_stats
+
+
+def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
+    ep_len = dataset.meta.episodes[episode_index]["length"]
+    sampled_indices = sample_indices(ep_len)
+    query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
+    video_frames = dataset._query_videos(query_timestamps, episode_index)
+    return video_frames[ft_key].numpy()
+
+
+def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
+    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
+    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
+
+    ep_stats = {}
+    for key, ft in dataset.features.items():
+        if ft["dtype"] == "video":
+            # We sample only for videos
+            ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
+        else:
+            ep_ft_data = np.array(ep_data[key])
+
+        axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
+        keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
+        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
+
+        if ft["dtype"] in ["image", "video"]:  # remove batch dim
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    dataset.meta.episodes_stats[ep_idx] = ep_stats
+
+
+def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
+    assert dataset.episodes is None
+    print("Computing episodes stats")
+    total_episodes = dataset.meta.total_episodes
+    if num_workers > 0:
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = {
+                executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
+                for ep_idx in range(total_episodes)
+            }
+            for future in tqdm(as_completed(futures), total=total_episodes):
+                future.result()
+    else:
+        for ep_idx in tqdm(range(total_episodes)):
+            convert_episode_stats(dataset, ep_idx)
+
+    for ep_idx in tqdm(range(total_episodes)):
+        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
+
+
+def check_aggregate_stats(
+    dataset: LeRobotDataset,
+    reference_stats: dict[str, dict[str, np.ndarray]],
+    video_rtol_atol: tuple[float, float] = (1e-2, 1e-2),
+    default_rtol_atol: tuple[float, float] = (5e-6, 0.0),
+):
+    """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
+    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
+    for key, ft in dataset.features.items():
+        # These values might need some fine-tuning
+        if ft["dtype"] == "video":
+            # to account for image sub-sampling
+            rtol, atol = video_rtol_atol
+        else:
+            rtol, atol = 
default_rtol_atol + + for stat, val in agg_stats[key].items(): + if key in reference_stats and stat in reference_stats[key]: + err_msg = f"feature='{key}' stats='{stat}'" + np.testing.assert_allclose( + val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg + ) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 8ed3318d..8be53483 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -69,8 +69,8 @@ def decode_video_frames_torchvision( # set the first and last requested timestamps # Note: previous timestamps are usually loaded, since we need to access the previous key frame - first_ts = timestamps[0] - last_ts = timestamps[-1] + first_ts = min(timestamps) + last_ts = max(timestamps) # access closest key frame of the first requested frame # Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 95219273..b3255ec1 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import torch from torch import Tensor, nn @@ -77,17 +78,29 @@ def create_stats_buffers( } ) + # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) if stats: - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[key]["mean"].clone() - buffer["std"].data = stats[key]["std"].clone() - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[key]["min"].clone() - buffer["max"].data = stats[key]["max"].clone() + if isinstance(stats[key]["mean"], np.ndarray): + if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) + buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) + elif norm_mode is NormalizationMode.MIN_MAX: + buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) + buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) + elif isinstance(stats[key]["mean"], torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. 
+ if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) + buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) + elif norm_mode is NormalizationMode.MIN_MAX: + buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) + buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) + else: + type_ = type(stats[key]["mean"]) + raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") stats_buffers[key] = buffer return stats_buffers @@ -141,6 +154,7 @@ class Normalize(nn.Module): batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): if key not in batch: + # FIXME(aliberts, rcadene): This might lead to silent fail! continue norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index a2f3889c..c96a87f0 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -60,8 +60,6 @@ class RecordControlConfig(ControlConfig): num_episodes: int = 50 # Encode frames in the dataset into video video: bool = True - # By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode. - run_compute_stats: bool = True # Upload dataset to Hugging Face hub. push_to_hub: bool = True # Upload on private repository on the Hugging Face hub. diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index de67e331..5f51c81b 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -301,10 +301,7 @@ def record( log_say("Stop recording", cfg.play_sounds, blocking=True) stop_recording(robot, listener, cfg.display_cameras) - if cfg.run_compute_stats: - logging.info("Computing dataset statistics") - - dataset.consolidate(cfg.run_compute_stats) + dataset.consolidate() if cfg.push_to_hub: dataset.push_to_hub(tags=cfg.tags, private=cfg.private) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index bdd2dc54..f57f945b 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -29,7 +29,7 @@ from tests.fixtures.constants import ( def get_task_index(task_dicts: dict, task: str) -> int: - tasks = {d["task_index"]: d["task"] for d in task_dicts} + tasks = {d["task_index"]: d["task"] for d in task_dicts.values()} task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} return task_to_task_index[task] @@ -142,6 +142,7 @@ def stats_factory(): "mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(), "min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(), "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), + "count": [10], } else: stats[key] = { @@ -149,20 +150,38 @@ def stats_factory(): "mean": np.full(shape, 0.5, dtype=dtype).tolist(), "min": np.full(shape, 0, dtype=dtype).tolist(), "std": np.full(shape, 0.25, dtype=dtype).tolist(), + "count": [10], } return stats return _create_stats +@pytest.fixture(scope="session") +def episodes_stats_factory(stats_factory): + def _create_episodes_stats( + features: dict[str], + total_episodes: int = 3, + ) -> dict: + episodes_stats = {} + for episode_index in range(total_episodes): + episodes_stats[episode_index] = { + "episode_index": episode_index, + "stats": stats_factory(features), + } + return 
episodes_stats + + return _create_episodes_stats + + @pytest.fixture(scope="session") def tasks_factory(): def _create_tasks(total_tasks: int = 3) -> int: - tasks_list = [] - for i in range(total_tasks): - task_dict = {"task_index": i, "task": f"Perform action {i}."} - tasks_list.append(task_dict) - return tasks_list + tasks = {} + for task_index in range(total_tasks): + task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."} + tasks[task_index] = task_dict + return tasks return _create_tasks @@ -191,10 +210,10 @@ def episodes_factory(tasks_factory): # Generate random lengths that sum up to total_length lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() - tasks_list = [task_dict["task"] for task_dict in tasks] + tasks_list = [task_dict["task"] for task_dict in tasks.values()] num_tasks_available = len(tasks_list) - episodes_list = [] + episodes = {} remaining_tasks = tasks_list.copy() for ep_idx in range(total_episodes): num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 @@ -204,15 +223,13 @@ def episodes_factory(tasks_factory): for task in episode_tasks: remaining_tasks.remove(task) - episodes_list.append( - { - "episode_index": ep_idx, - "tasks": episode_tasks, - "length": lengths[ep_idx], - } - ) + episodes[ep_idx] = { + "episode_index": ep_idx, + "tasks": episode_tasks, + "length": lengths[ep_idx], + } - return episodes_list + return episodes return _create_episodes @@ -236,7 +253,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar frame_index_col = np.array([], dtype=np.int64) episode_index_col = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64) - for ep_dict in episodes: + for ep_dict in episodes.values(): timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) episode_index_col = np.concatenate( @@ -279,6 +296,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar def lerobot_dataset_metadata_factory( info_factory, stats_factory, + episodes_stats_factory, tasks_factory, episodes_factory, mock_snapshot_download_factory, @@ -288,6 +306,7 @@ def lerobot_dataset_metadata_factory( repo_id: str = DUMMY_REPO_ID, info: dict | None = None, stats: dict | None = None, + episodes_stats: list[dict] | None = None, tasks: list[dict] | None = None, episodes: list[dict] | None = None, local_files_only: bool = False, @@ -296,6 +315,10 @@ def lerobot_dataset_metadata_factory( info = info_factory() if not stats: stats = stats_factory(features=info["features"]) + if not episodes_stats: + episodes_stats = episodes_stats_factory( + features=info["features"], total_episodes=info["total_episodes"] + ) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episodes: @@ -306,6 +329,7 @@ def lerobot_dataset_metadata_factory( mock_snapshot_download = mock_snapshot_download_factory( info=info, stats=stats, + episodes_stats=episodes_stats, tasks=tasks, episodes=episodes, ) @@ -329,6 +353,7 @@ def lerobot_dataset_metadata_factory( def lerobot_dataset_factory( info_factory, stats_factory, + episodes_stats_factory, tasks_factory, episodes_factory, hf_dataset_factory, @@ -344,6 +369,7 @@ def lerobot_dataset_factory( multi_task: bool = False, info: dict | None = None, stats: dict | None = None, + episodes_stats: list[dict] | None = None, tasks: list[dict] | None = None, episode_dicts: 
list[dict] | None = None, hf_dataset: datasets.Dataset | None = None, @@ -355,6 +381,8 @@ def lerobot_dataset_factory( ) if not stats: stats = stats_factory(features=info["features"]) + if not episodes_stats: + episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episode_dicts: @@ -370,6 +398,7 @@ def lerobot_dataset_factory( mock_snapshot_download = mock_snapshot_download_factory( info=info, stats=stats, + episodes_stats=episodes_stats, tasks=tasks, episodes=episode_dicts, hf_dataset=hf_dataset, @@ -379,6 +408,7 @@ def lerobot_dataset_factory( repo_id=repo_id, info=info, stats=stats, + episodes_stats=episodes_stats, tasks=tasks, episodes=episode_dicts, local_files_only=kwargs.get("local_files_only", False), @@ -406,7 +436,7 @@ def empty_lerobot_dataset_factory(): robot: Robot | None = None, robot_type: str | None = None, features: dict | None = None, - ): + ) -> LeRobotDataset: return LeRobotDataset.create( repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features ) diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 5fe8a314..4ef12e49 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -7,7 +7,13 @@ import pyarrow.compute as pc import pyarrow.parquet as pq import pytest -from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH +from lerobot.common.datasets.utils import ( + EPISODES_PATH, + EPISODES_STATS_PATH, + INFO_PATH, + STATS_PATH, + TASKS_PATH, +) @pytest.fixture(scope="session") @@ -38,6 +44,20 @@ def stats_path(stats_factory): return _create_stats_json_file +@pytest.fixture(scope="session") +def episodes_stats_path(episodes_stats_factory): + def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: + if not episodes_stats: + episodes_stats = episodes_stats_factory() + fpath = dir / EPISODES_STATS_PATH + fpath.parent.mkdir(parents=True, exist_ok=True) + with jsonlines.open(fpath, "w") as writer: + writer.write_all(episodes_stats.values()) + return fpath + + return _create_episodes_stats_jsonl_file + + @pytest.fixture(scope="session") def tasks_path(tasks_factory): def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: @@ -46,7 +66,7 @@ def tasks_path(tasks_factory): fpath = dir / TASKS_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(fpath, "w") as writer: - writer.write_all(tasks) + writer.write_all(tasks.values()) return fpath return _create_tasks_jsonl_file @@ -60,7 +80,7 @@ def episode_path(episodes_factory): fpath = dir / EPISODES_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(fpath, "w") as writer: - writer.write_all(episodes) + writer.write_all(episodes.values()) return fpath return _create_episodes_jsonl_file diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 351768c0..ae309cb4 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -4,7 +4,13 @@ import datasets import pytest from huggingface_hub.utils import filter_repo_objects -from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH +from lerobot.common.datasets.utils import ( + EPISODES_PATH, + EPISODES_STATS_PATH, + INFO_PATH, + STATS_PATH, + TASKS_PATH, +) from tests.fixtures.constants import LEROBOT_TEST_DIR @@ -14,6 +20,8 @@ def mock_snapshot_download_factory( info_path, stats_factory, stats_path, + episodes_stats_factory, + 
episodes_stats_path, tasks_factory, tasks_path, episodes_factory, @@ -29,6 +37,7 @@ def mock_snapshot_download_factory( def _mock_snapshot_download_func( info: dict | None = None, stats: dict | None = None, + episodes_stats: list[dict] | None = None, tasks: list[dict] | None = None, episodes: list[dict] | None = None, hf_dataset: datasets.Dataset | None = None, @@ -37,6 +46,10 @@ def mock_snapshot_download_factory( info = info_factory() if not stats: stats = stats_factory(features=info["features"]) + if not episodes_stats: + episodes_stats = episodes_stats_factory( + features=info["features"], total_episodes=info["total_episodes"] + ) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episodes: @@ -67,11 +80,11 @@ def mock_snapshot_download_factory( # List all possible files all_files = [] - meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH] + meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH] all_files.extend(meta_files) data_files = [] - for episode_dict in episodes: + for episode_dict in episodes.values(): ep_idx = episode_dict["episode_index"] ep_chunk = ep_idx // info["chunks_size"] data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) @@ -92,6 +105,8 @@ def mock_snapshot_download_factory( _ = info_path(local_dir, info) elif rel_path == STATS_PATH: _ = stats_path(local_dir, stats) + elif rel_path == EPISODES_STATS_PATH: + _ = episodes_stats_path(local_dir, episodes_stats) elif rel_path == TASKS_PATH: _ = tasks_path(local_dir, tasks) elif rel_path == EPISODES_PATH: diff --git a/tests/test_cameras.py b/tests/test_cameras.py index 1a1812f7..7c043c25 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -182,7 +182,7 @@ def test_camera(request, camera_type, mock): @pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) @require_camera -def test_save_images_from_cameras(tmpdir, request, camera_type, mock): +def test_save_images_from_cameras(tmp_path, request, camera_type, mock): # TODO(rcadene): refactor if camera_type == "opencv": from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras @@ -190,4 +190,4 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock): from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras # Small `record_time_s` to speedup unit tests - save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock) + save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock) diff --git a/tests/test_compute_stats.py b/tests/test_compute_stats.py new file mode 100644 index 00000000..d9032c8a --- /dev/null +++ b/tests/test_compute_stats.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest.mock import patch + +import numpy as np +import pytest + +from lerobot.common.datasets.compute_stats import ( + _assert_type_and_shape, + aggregate_feature_stats, + aggregate_stats, + compute_episode_stats, + estimate_num_samples, + get_feature_stats, + sample_images, + sample_indices, +) + + +def mock_load_image_as_numpy(path, dtype, channel_first): + return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) + + +@pytest.fixture +def sample_array(): + return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + +def test_estimate_num_samples(): + assert estimate_num_samples(1) == 1 + assert estimate_num_samples(10) == 10 + assert estimate_num_samples(100) == 100 + assert estimate_num_samples(200) == 100 + assert estimate_num_samples(1000) == 177 + assert estimate_num_samples(2000) == 299 + assert estimate_num_samples(5000) == 594 + assert estimate_num_samples(10_000) == 1000 + assert estimate_num_samples(20_000) == 1681 + assert estimate_num_samples(50_000) == 3343 + assert estimate_num_samples(500_000) == 10_000 + + +def test_sample_indices(): + indices = sample_indices(10) + assert len(indices) > 0 + assert indices[0] == 0 + assert indices[-1] == 9 + assert len(indices) == estimate_num_samples(10) + + +@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy) +def test_sample_images(mock_load): + image_paths = [f"image_{i}.jpg" for i in range(100)] + images = sample_images(image_paths) + assert isinstance(images, np.ndarray) + assert images.shape[1:] == (3, 32, 32) + assert images.dtype == np.uint8 + assert len(images) == estimate_num_samples(100) + + +def test_get_feature_stats_images(): + data = np.random.rand(100, 3, 32, 32) + stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) + assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats + np.testing.assert_equal(stats["count"], np.array([100])) + assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + + +def test_get_feature_stats_axis_0_keepdims(sample_array): + expected = { + "min": np.array([[1, 2, 3]]), + "max": np.array([[7, 8, 9]]), + "mean": np.array([[4.0, 5.0, 6.0]]), + "std": np.array([[2.44948974, 2.44948974, 2.44948974]]), + "count": np.array([3]), + } + result = get_feature_stats(sample_array, axis=(0,), keepdims=True) + for key in expected: + np.testing.assert_allclose(result[key], expected[key]) + + +def test_get_feature_stats_axis_1(sample_array): + expected = { + "min": np.array([1, 4, 7]), + "max": np.array([3, 6, 9]), + "mean": np.array([2.0, 5.0, 8.0]), + "std": np.array([0.81649658, 0.81649658, 0.81649658]), + "count": np.array([3]), + } + result = get_feature_stats(sample_array, axis=(1,), keepdims=False) + for key in expected: + np.testing.assert_allclose(result[key], expected[key]) + + +def test_get_feature_stats_no_axis(sample_array): + expected = { + "min": np.array(1), + "max": np.array(9), + "mean": np.array(5.0), + "std": np.array(2.5819889), + "count": np.array([3]), + } + result = get_feature_stats(sample_array, axis=None, keepdims=False) + for key in expected: + np.testing.assert_allclose(result[key], expected[key]) + + +def test_get_feature_stats_empty_array(): + array = np.array([]) + with pytest.raises(ValueError): + get_feature_stats(array, axis=(0,), keepdims=True) + + +def test_get_feature_stats_single_value(): + array = np.array([[1337]]) + result = get_feature_stats(array, axis=None, keepdims=True) + 
np.testing.assert_equal(result["min"], np.array(1337)) + np.testing.assert_equal(result["max"], np.array(1337)) + np.testing.assert_equal(result["mean"], np.array(1337.0)) + np.testing.assert_equal(result["std"], np.array(0.0)) + np.testing.assert_equal(result["count"], np.array([1])) + + +def test_compute_episode_stats(): + episode_data = { + "observation.image": [f"image_{i}.jpg" for i in range(100)], + "observation.state": np.random.rand(100, 10), + } + features = { + "observation.image": {"dtype": "image"}, + "observation.state": {"dtype": "numeric"}, + } + + with patch( + "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + ): + stats = compute_episode_stats(episode_data, features) + + assert "observation.image" in stats and "observation.state" in stats + assert stats["observation.image"]["count"].item() == 100 + assert stats["observation.state"]["count"].item() == 100 + assert stats["observation.image"]["mean"].shape == (3, 1, 1) + + +def test_assert_type_and_shape_valid(): + valid_stats = [ + { + "feature1": { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([1]), + } + } + ] + _assert_type_and_shape(valid_stats) + + +def test_assert_type_and_shape_invalid_type(): + invalid_stats = [ + { + "feature1": { + "min": [1.0], # Not a numpy array + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([1]), + } + } + ] + with pytest.raises(ValueError, match="Stats must be composed of numpy array"): + _assert_type_and_shape(invalid_stats) + + +def test_assert_type_and_shape_invalid_shape(): + invalid_stats = [ + { + "feature1": { + "count": np.array([1, 2]), # Wrong shape + } + } + ] + with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"): + _assert_type_and_shape(invalid_stats) + + +def test_aggregate_feature_stats(): + stats_ft_list = [ + { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([1]), + }, + { + "min": np.array([2.0]), + "max": np.array([12.0]), + "mean": np.array([6.0]), + "std": np.array([2.5]), + "count": np.array([1]), + }, + ] + result = aggregate_feature_stats(stats_ft_list) + np.testing.assert_allclose(result["min"], np.array([1.0])) + np.testing.assert_allclose(result["max"], np.array([12.0])) + np.testing.assert_allclose(result["mean"], np.array([5.5])) + np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6) + np.testing.assert_allclose(result["count"], np.array([2])) + + +def test_aggregate_stats(): + all_stats = [ + { + "observation.image": { + "min": [1, 2, 3], + "max": [10, 20, 30], + "mean": [5.5, 10.5, 15.5], + "std": [2.87, 5.87, 8.87], + "count": 10, + }, + "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, + "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, + }, + { + "observation.image": { + "min": [2, 1, 0], + "max": [15, 10, 5], + "mean": [8.5, 5.5, 2.5], + "std": [3.42, 2.42, 1.42], + "count": 15, + }, + "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, + "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, + }, + ] + + expected_agg_stats = { + "observation.image": { + "min": [1, 1, 0], + "max": [15, 20, 30], + "mean": [7.3, 7.5, 7.7], + "std": [3.5317, 4.8267, 8.5581], + "count": 25, + }, + "observation.state": { + "min": 1, + "max": 15, + "mean": 7.3, + "std": 3.5317, + 
"count": 25, + }, + "extra_key_0": { + "min": 5, + "max": 25, + "mean": 15.0, + "std": 6.0, + "count": 6, + }, + "extra_key_1": { + "min": 0, + "max": 20, + "mean": 10.0, + "std": 5.0, + "count": 5, + }, + } + + # cast to numpy + for ep_stats in all_stats: + for fkey, stats in ep_stats.items(): + for k in stats: + stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) + if fkey == "observation.image" and k != "count": + stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels + else: + stats[k] = stats[k].reshape(1) + + # cast to numpy + for fkey, stats in expected_agg_stats.items(): + for k in stats: + stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) + if fkey == "observation.image" and k != "count": + stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels + else: + stats[k] = stats[k].reshape(1) + + results = aggregate_stats(all_stats) + + for fkey in expected_agg_stats: + np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"]) + np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"]) + np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"]) + np.testing.assert_allclose( + results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04 + ) + np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 36ee096f..a4f538a6 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -24,7 +24,6 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]' """ import multiprocessing -from pathlib import Path from unittest.mock import patch import pytest @@ -45,7 +44,7 @@ from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_ @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_teleoperate(tmpdir, request, robot_type, mock): +def test_teleoperate(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -53,8 +52,7 @@ def test_teleoperate(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - tmpdir = Path(tmpdir) - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -70,15 +68,14 @@ def test_teleoperate(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_calibrate(tmpdir, request, robot_type, mock): +def test_calibrate(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: request.getfixturevalue("patch_builtins_input") # Create an empty calibration directory to trigger manual calibration - tmpdir = Path(tmpdir) - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type robot_kwargs["calibration_dir"] = calibration_dir robot = make_robot(**robot_kwargs) @@ -89,7 +86,7 @@ def test_calibrate(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_record_without_cameras(tmpdir, request, robot_type, mock): +def test_record_without_cameras(tmp_path, request, robot_type, mock): robot_kwargs = 
{"robot_type": robot_type, "mock": mock} # Avoid using cameras @@ -100,7 +97,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = Path(tmpdir) / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -108,7 +105,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): pass repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." robot = make_robot(**robot_kwargs) @@ -121,7 +118,6 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): episode_time_s=1, reset_time_s=0.1, num_episodes=2, - run_compute_stats=False, push_to_hub=False, video=False, play_sounds=False, @@ -131,8 +127,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): - tmpdir = Path(tmpdir) +def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -140,7 +135,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -148,7 +143,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): pass repo_id = "lerobot_test/debug" - root = tmpdir / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." 
robot = make_robot(**robot_kwargs) @@ -180,7 +175,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): policy_cfg = ACTConfig() policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE) - out_dir = tmpdir / "logger" + out_dir = tmp_path / "logger" pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model" policy.save_pretrained(pretrained_policy_path) @@ -207,7 +202,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): num_image_writer_processes = 0 eval_repo_id = "lerobot/eval_debug" - eval_root = tmpdir / "data" / eval_repo_id + eval_root = tmp_path / "data" / eval_repo_id rec_eval_cfg = RecordControlConfig( repo_id=eval_repo_id, @@ -218,7 +213,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): episode_time_s=1, reset_time_s=0.1, num_episodes=2, - run_compute_stats=False, push_to_hub=False, video=False, display_cameras=False, @@ -240,7 +234,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @require_robot -def test_resume_record(tmpdir, request, robot_type, mock): +def test_resume_record(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -248,7 +242,7 @@ def test_resume_record(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -258,7 +252,7 @@ def test_resume_record(tmpdir, request, robot_type, mock): robot = make_robot(**robot_kwargs) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." rec_cfg = RecordControlConfig( @@ -272,7 +266,6 @@ def test_resume_record(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, local_files_only=True, num_episodes=1, ) @@ -291,7 +284,7 @@ def test_resume_record(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @require_robot -def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): +def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -299,7 +292,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -316,7 +309,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): mock_listener.return_value = (None, mock_events) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." 
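+    # run_compute_stats is removed from RecordControlConfig in this change;
+    # episode stats are presumably computed when each episode is saved
+    # (see compute_episode_stats), so no flag is needed here anymore.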
rec_cfg = RecordControlConfig( @@ -331,7 +324,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, ) dataset = record(robot, rec_cfg) @@ -342,7 +334,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @require_robot -def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): +def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: @@ -350,7 +342,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -367,7 +359,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): mock_listener.return_value = (None, mock_events) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." rec_cfg = RecordControlConfig( @@ -382,7 +374,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, ) dataset = record(robot, rec_cfg) @@ -395,7 +386,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)] ) @require_robot -def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes): +def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: @@ -403,7 +394,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -420,7 +411,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num mock_listener.return_value = (None, mock_events) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." 
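+    # The parametrization above covers both the threaded image writer
+    # (num_image_writer_processes=0) and the multiprocess one (=1).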
rec_cfg = RecordControlConfig( @@ -436,7 +427,6 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, num_image_writer_processes=num_image_writer_processes, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 54d92125..b1df9b46 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -20,21 +20,14 @@ from copy import deepcopy from itertools import chain from pathlib import Path -import einops import numpy as np import pytest import torch -from datasets import Dataset from huggingface_hub import HfApi from PIL import Image from safetensors.torch import load_file import lerobot -from lerobot.common.datasets.compute_stats import ( - aggregate_stats, - compute_stats, - get_stats_einops_patterns, -) from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.image_writer import image_array_to_pil_image from lerobot.common.datasets.lerobot_dataset import ( @@ -44,13 +37,11 @@ from lerobot.common.datasets.lerobot_dataset import ( from lerobot.common.datasets.utils import ( create_branch, flatten_dict, - hf_transform_to_torch, unflatten_dict, ) from lerobot.common.envs.factory import make_env_config from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot -from lerobot.common.utils.random_utils import seeded_context from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID @@ -196,12 +187,12 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact def test_add_frame(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(1), "task": "dummy"}) + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert len(dataset) == 1 - assert dataset[0]["task"] == "dummy" + assert dataset[0]["task"] == "Dummy task" assert dataset[0]["task_index"] == 0 assert dataset[0]["state"].ndim == 0 @@ -209,9 +200,9 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2), "task": "dummy"}) + dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["state"].shape == torch.Size([2]) @@ -219,9 +210,9 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4), "task": "dummy"}) + dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + 
dataset.consolidate() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -229,9 +220,9 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "dummy"}) + dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -239,9 +230,9 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "dummy"}) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) @@ -249,9 +240,9 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "dummy"}) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) @@ -261,7 +252,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["state"].ndim == 0 @@ -271,7 +262,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["caption"] == "Dummy caption" @@ -307,7 +298,7 @@ def test_add_frame_image(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -316,7 +307,7 @@ def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -326,7 +317,7 @@ def 
test_add_frame_image_uint8(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": image, "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -336,7 +327,7 @@ def test_add_frame_image_pil(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) dataset.save_episode(encode_videos=False) - dataset.consolidate(run_compute_stats=False) + dataset.consolidate() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -463,67 +454,6 @@ def test_multidataset_frames(): assert torch.equal(sub_dataset_item[k], dataset_item[k]) -# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py -def test_compute_stats_on_xarm(): - """Check that the statistics are computed correctly according to the stats_patterns property. - - We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do - because we are working with a small dataset). - """ - # TODO(rcadene, aliberts): remove dataset download - dataset = LeRobotDataset("lerobot/xarm_lift_medium", episodes=[0]) - - # reduce size of dataset sample on which stats compute is tested to 10 frames - dataset.hf_dataset = dataset.hf_dataset.select(range(10)) - - # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched - # computation of the statistics. While doing this, we also make sure it works when we don't divide the - # dataset into even batches. - computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0) - - # get einops patterns to aggregate batches and compute statistics - stats_patterns = get_stats_einops_patterns(dataset) - - # get all frames from the dataset in the same dtype and range as during compute_stats - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=0, - batch_size=len(dataset), - shuffle=False, - ) - full_batch = next(iter(dataloader)) - - # compute stats based on all frames from the dataset without any batching - expected_stats = {} - for k, pattern in stats_patterns.items(): - full_batch[k] = full_batch[k].float() - expected_stats[k] = {} - expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean") - expected_stats[k]["std"] = torch.sqrt( - einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean") - ) - expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min") - expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max") - - # test computed stats match expected stats - for k in stats_patterns: - assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"]) - assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"]) - assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"]) - assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"]) - - # load stats used during training which are expected to match the ones returned by computed_stats - loaded_stats = dataset.meta.stats # noqa: F841 - - # TODO(rcadene): we can't test this because expected_stats is computed on a subset - # # test loaded stats match expected stats - # for k in stats_patterns: - # assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"]) - # assert torch.allclose(loaded_stats[k]["std"], 
expected_stats[k]["std"]) - # assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"]) - # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) - - # TODO(aliberts): Move to more appropriate location def test_flatten_unflatten_dict(): d = { @@ -627,35 +557,6 @@ def test_backward_compatibility(repo_id): # load_and_compare(i - 1) -@pytest.mark.skip("TODO after fix multidataset") -def test_multidataset_aggregate_stats(): - """Makes 3 basic datasets and checks that aggregate stats are computed correctly.""" - with seeded_context(0): - data_a = torch.rand(30, dtype=torch.float32) - data_b = torch.rand(20, dtype=torch.float32) - data_c = torch.rand(20, dtype=torch.float32) - - hf_dataset_1 = Dataset.from_dict( - {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)} - ) - hf_dataset_1.set_transform(hf_transform_to_torch) - hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}) - hf_dataset_2.set_transform(hf_transform_to_torch) - hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}) - hf_dataset_3.set_transform(hf_transform_to_torch) - dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1) - dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0) - dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2) - dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0) - dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3) - dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0) - stats = aggregate_stats([dataset_1, dataset_2, dataset_3]) - for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True): - for agg_fn in ["mean", "min", "max"]: - assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)) - assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0)) - - @pytest.mark.skip("Requires internet access") def test_create_branch(): api = HfApi() diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py deleted file mode 100644 index a0c8d908..00000000 --- a/tests/test_push_dataset_to_hub.py +++ /dev/null @@ -1,370 +0,0 @@ -""" -This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API. -Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets, -we skip them for now in our CI. 
- -Example to run backward compatiblity tests locally: -``` -python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility -``` -""" - -from pathlib import Path - -import numpy as np -import pytest -import torch - -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently -from lerobot.common.datasets.video_utils import encode_video_frames -from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub -from tests.utils import require_package_arg - - -def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3): - import zarr - - raw_dir.mkdir(parents=True, exist_ok=True) - zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr" - store = zarr.DirectoryStore(zarr_path) - zarr_data = zarr.group(store=store) - - zarr_data.create_dataset( - "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/img", - shape=(num_frames, 96, 96, 3), - chunks=(num_frames, 96, 96, 3), - dtype=np.uint8, - overwrite=True, - ) - zarr_data.create_dataset( - "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True - ) - - zarr_data["data/action"][:] = np.random.randn(num_frames, 1) - zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) - zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2) - zarr_data["data/state"][:] = np.random.randn(num_frames, 5) - zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2) - zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) - - store.close() - - -def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3): - import zarr - - raw_dir.mkdir(parents=True, exist_ok=True) - zarr_path = raw_dir / "cup_in_the_wild.zarr" - store = zarr.DirectoryStore(zarr_path) - zarr_data = zarr.group(store=store) - - zarr_data.create_dataset( - "data/camera0_rgb", - shape=(num_frames, 96, 96, 3), - chunks=(num_frames, 96, 96, 3), - dtype=np.uint8, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_demo_end_pose", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_demo_start_pose", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/robot0_eef_rot_axis_angle", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_gripper_width", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True - ) - - zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) - 
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5) - zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) - - store.close() - - -def _mock_download_raw_xarm(raw_dir, num_frames=4): - import pickle - - dataset_dict = { - "observations": { - "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8), - "state": np.random.randn(num_frames, 4), - }, - "actions": np.random.randn(num_frames, 3), - "rewards": np.random.randn(num_frames), - "masks": np.random.randn(num_frames), - "dones": np.array([False, True, True, True]), - } - - raw_dir.mkdir(parents=True, exist_ok=True) - pkl_path = raw_dir / "buffer.pkl" - with open(pkl_path, "wb") as f: - pickle.dump(dataset_dict, f) - - -def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3): - import h5py - - for ep_idx in range(num_episodes): - raw_dir.mkdir(parents=True, exist_ok=True) - path_h5 = raw_dir / f"episode_{ep_idx}.hdf5" - with h5py.File(str(path_h5), "w") as f: - f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14)) - f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14)) - f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14)) - f.create_dataset( - "observations/images/top", - data=np.random.randint( - 0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8 - ), - ) - - -def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30): - from datetime import datetime, timedelta, timezone - - import pandas - - def write_parquet(key, timestamps, values): - data = { - "timestamp_utc": timestamps, - key: values, - } - df = pandas.DataFrame(data) - raw_dir.mkdir(parents=True, exist_ok=True) - df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow") - - episode_indices = [None, None, -1, None, None, -1, None, None, -1] - episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2] - frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1] - - cam_key = "observation.images.cam_high" - timestamps = [] - actions = [] - states = [] - frames = [] - # `+ num_episodes`` for buffer frames associated to episode_index=-1 - for i, frame_idx in enumerate(frame_indices): - t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps) - action = np.random.randn(21).tolist() - state = np.random.randn(21).tolist() - ep_idx = episode_indices_mapping[i] - frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}] - timestamps.append(t_utc) - actions.append(action) - states.append(state) - frames.append(frame) - - write_parquet(cam_key, timestamps, frames) - write_parquet("observation.state", timestamps, states) - write_parquet("action", timestamps, actions) - write_parquet("episode_index", timestamps, episode_indices) - - # write fake mp4 file for each episode - for ep_idx in range(num_episodes): - imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8) - - tmp_imgs_dir = raw_dir / "tmp_images" - save_images_concurrently(imgs_array, tmp_imgs_dir) - - fname = f"{cam_key}_episode_{ep_idx:06d}.mp4" - video_path = raw_dir / "videos" / fname - encode_video_frames(tmp_imgs_dir, video_path, fps, vcodec="libx264") - 
- -def _mock_download_raw(raw_dir, repo_id): - if "wrist_gripper" in repo_id: - _mock_download_raw_dora(raw_dir) - elif "aloha" in repo_id: - _mock_download_raw_aloha(raw_dir) - elif "pusht" in repo_id: - _mock_download_raw_pusht(raw_dir) - elif "xarm" in repo_id: - _mock_download_raw_xarm(raw_dir) - elif "umi" in repo_id: - _mock_download_raw_umi(raw_dir) - else: - raise ValueError(repo_id) - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -def test_push_dataset_to_hub_invalid_repo_id(tmpdir): - with pytest.raises(ValueError): - push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id") - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir): - tmpdir = Path(tmpdir) - out_dir = tmpdir / "out" - raw_dir = tmpdir / "raw" - # mkdir to skip download - raw_dir.mkdir(parents=True, exist_ok=True) - with pytest.raises(ValueError): - push_dataset_to_hub( - raw_dir=raw_dir, - raw_format="some_format", - repo_id="user/dataset", - local_dir=out_dir, - force_override=False, - ) - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -@pytest.mark.parametrize( - "required_packages, raw_format, repo_id, make_test_data", - [ - (["gym_pusht"], "pusht_zarr", "lerobot/pusht", False), - (["gym_pusht"], "pusht_zarr", "lerobot/pusht", True), - (None, "xarm_pkl", "lerobot/xarm_lift_medium", False), - (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False), - (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False), - (None, "dora_parquet", "cadene/wrist_gripper", False), - ], -) -@require_package_arg -def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data): - num_episodes = 3 - tmpdir = Path(tmpdir) - - raw_dir = tmpdir / f"{repo_id}_raw" - _mock_download_raw(raw_dir, repo_id) - - local_dir = tmpdir / repo_id - - lerobot_dataset = push_dataset_to_hub( - raw_dir=raw_dir, - raw_format=raw_format, - repo_id=repo_id, - push_to_hub=False, - local_dir=local_dir, - force_override=False, - cache_dir=tmpdir / "cache", - tests_data_dir=tmpdir / "tests/data" if make_test_data else None, - encoding={"vcodec": "libx264"}, - ) - - # minimal generic tests on the local directory containing LeRobotDataset - assert (local_dir / "meta_data" / "info.json").exists() - assert (local_dir / "meta_data" / "stats.safetensors").exists() - assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists() - for i in range(num_episodes): - for cam_key in lerobot_dataset.camera_keys: - assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists() - assert (local_dir / "train" / "dataset_info.json").exists() - assert (local_dir / "train" / "state.json").exists() - assert len(list((local_dir / "train").glob("*.arrow"))) > 0 - - # minimal generic tests on the item - item = lerobot_dataset[0] - assert "index" in item - assert "episode_index" in item - assert "timestamp" in item - for cam_key in lerobot_dataset.camera_keys: - assert cam_key in item - - if make_test_data: - # Check that only the first episode is selected. 
- test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data") - num_frames = sum( - i == lerobot_dataset.hf_dataset["episode_index"][0] - for i in lerobot_dataset.hf_dataset["episode_index"] - ).item() - assert ( - test_dataset.hf_dataset["episode_index"] - == lerobot_dataset.hf_dataset["episode_index"][:num_frames] - ) - for k in ["from", "to"]: - assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1]) - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -@pytest.mark.parametrize( - "raw_format, repo_id", - [ - # TODO(rcadene): add raw dataset test artifacts - ("pusht_zarr", "lerobot/pusht"), - ("xarm_pkl", "lerobot/xarm_lift_medium"), - ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), - ("umi_zarr", "lerobot/umi_cup_in_the_wild"), - ("dora_parquet", "cadene/wrist_gripper"), - ], -) -def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id): - _, dataset_id = repo_id.split("/") - - tmpdir = Path(tmpdir) - raw_dir = tmpdir / f"{dataset_id}_raw" - local_dir = tmpdir / repo_id - - push_dataset_to_hub( - raw_dir=raw_dir, - raw_format=raw_format, - repo_id=repo_id, - push_to_hub=False, - local_dir=local_dir, - force_override=False, - cache_dir=tmpdir / "cache", - episodes=[0], - ) - - ds_actual = LeRobotDataset(repo_id, root=tmpdir) - ds_reference = LeRobotDataset(repo_id) - - assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset) - - def check_same_items(item1, item2): - assert item1.keys() == item2.keys(), "Keys mismatch" - - for key in item1: - if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor): - assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}" - else: - assert item1[key] == item2[key], f"Mismatch found in key: {key}" - - for i in range(len(ds_reference.hf_dataset)): - item_reference = ds_reference.hf_dataset[i] - item_actual = ds_actual.hf_dataset[i] - check_same_items(item_reference, item_actual) diff --git a/tests/test_robots.py b/tests/test_robots.py index e03b5f78..6c300b71 100644 --- a/tests/test_robots.py +++ b/tests/test_robots.py @@ -23,8 +23,6 @@ pytest -sx 'tests/test_robots.py::test_robot[aloha-True]' ``` """ -from pathlib import Path - import pytest import torch @@ -35,7 +33,7 @@ from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_robot(tmpdir, request, robot_type, mock): +def test_robot(tmp_path, request, robot_type, mock): # TODO(rcadene): measure fps in nightly? # TODO(rcadene): test logs # TODO(rcadene): add compatibility with other robots @@ -50,8 +48,7 @@ def test_robot(tmpdir, request, robot_type, mock): request.getfixturevalue("patch_builtins_input") # Create an empty calibration directory to trigger manual calibration - tmpdir = Path(tmpdir) - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir