diff --git a/.gitignore b/.gitignore index 5b73b9ad..4ccf404d 100644 --- a/.gitignore +++ b/.gitignore @@ -121,7 +121,6 @@ celerybeat.pid # Environments .env .venv -env/ venv/ ENV/ env.bak/ diff --git a/lerobot/__init__.py b/lerobot/__init__.py index e0234f29..a5a90fb4 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -45,6 +45,9 @@ import itertools from lerobot.__version__ import __version__ # noqa: F401 +# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies` +# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to +# a yaml file AND a environment name. The difference should be more obvious. available_tasks_per_env = { "aloha": [ "AlohaInsertion-v0", @@ -52,6 +55,7 @@ available_tasks_per_env = { ], "pusht": ["PushT-v0"], "xarm": ["XarmLift-v0"], + "dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"], } available_envs = list(available_tasks_per_env.keys()) @@ -77,6 +81,23 @@ available_datasets_per_env = { "lerobot/xarm_push_medium_image", "lerobot/xarm_push_medium_replay_image", ], + "dora_aloha_real": [ + "lerobot/aloha_static_battery", + "lerobot/aloha_static_candy", + "lerobot/aloha_static_coffee", + "lerobot/aloha_static_coffee_new", + "lerobot/aloha_static_cups_open", + "lerobot/aloha_static_fork_pick_up", + "lerobot/aloha_static_pingpong_test", + "lerobot/aloha_static_pro_pencil", + "lerobot/aloha_static_screw_driver", + "lerobot/aloha_static_tape", + "lerobot/aloha_static_thread_velcro", + "lerobot/aloha_static_towel", + "lerobot/aloha_static_vinh_cup", + "lerobot/aloha_static_vinh_cup_left", + "lerobot/aloha_static_ziploc_slide", + ], } available_real_world_datasets = [ @@ -108,16 +129,19 @@ available_datasets = list( itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets) ) +# lists all available policies from `lerobot/common/policies` by their class attribute: `name`. available_policies = [ "act", "diffusion", "tdmpc", ] +# keys and values refer to yaml files available_policies_per_env = { "aloha": ["act"], "pusht": ["diffusion"], "xarm": ["tdmpc"], + "dora_aloha_real": ["act_real"], } env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/compute_stats.py similarity index 70% rename from lerobot/common/datasets/push_dataset_to_hub/compute_stats.py rename to lerobot/common/datasets/compute_stats.py index ec296658..a69bc573 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -16,17 +16,15 @@ from copy import deepcopy from math import ceil -import datasets import einops import torch import tqdm from datasets import Image -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.video_utils import VideoFrame -def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0): +def get_stats_einops_patterns(dataset, num_workers=0): """These einops patterns will be used to aggregate batches and compute statistics. Note: We assume the images are in channel first format @@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo return stats_patterns -def compute_stats( - dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None -): +def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None): + """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset.""" if max_num_samples is None: max_num_samples = len(dataset) @@ -159,3 +156,54 @@ def compute_stats( "min": min[key], } return stats + + +def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]: + """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch. + + The final stats will have the union of all data keys from each of the datasets. + + The final stats will have the union of all data keys from each of the datasets. For instance: + - new_max = max(max_dataset_0, max_dataset_1, ...) + - new_min = min(min_dataset_0, min_dataset_1, ...) + - new_mean = (mean of all data) + - new_std = (std of all data) + """ + data_keys = set() + for dataset in ls_datasets: + data_keys.update(dataset.stats.keys()) + stats = {k: {} for k in data_keys} + for data_key in data_keys: + for stat_key in ["min", "max"]: + # compute `max(dataset_0["max"], dataset_1["max"], ...)` + stats[data_key][stat_key] = einops.reduce( + torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0), + "n ... -> ...", + stat_key, + ) + total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats) + # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective + # dataset, then divide by total_samples to get the overall "mean". + # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of + # numerical overflow! + stats[data_key]["mean"] = sum( + d.stats[data_key]["mean"] * (d.num_samples / total_samples) + for d in ls_datasets + if data_key in d.stats + ) + # The derivation for standard deviation is a little more involved but is much in the same spirit as + # the computation of the mean. + # Given two sets of data where the statistics are known: + # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ] + # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined + # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of + # numerical overflow! + stats[data_key]["std"] = torch.sqrt( + sum( + (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2) + * (d.num_samples / total_samples) + for d in ls_datasets + if data_key in d.stats + ) + ) + return stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 0bd09ac4..d9ce7a0f 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -16,10 +16,10 @@ import logging import torch +from omegaconf import ListConfig, OmegaConf from torchvision.transforms import v2 -from omegaconf import OmegaConf -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset def resolve_delta_timestamps(cfg): @@ -36,32 +36,73 @@ def resolve_delta_timestamps(cfg): cfg.training.delta_timestamps[key] = eval(delta_timestamps[key]) -def make_dataset( - cfg, - split="train", -): - if cfg.env.name not in cfg.dataset_repo_id: - logging.warning( - f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your " - f"environment ({cfg.env.name=})." +def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset: + """ + Args: + cfg: A Hydra config as per the LeRobot config scheme. + split: Select the data subset used to create an instance of LeRobotDataset. + All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train". + Thus, by default, `split="train"` selects all the available data. `split` aims to work like the + slicer in the hugging face datasets: + https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits + As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or + `split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`. + Returns: + The LeRobotDataset. + """ + if not isinstance(cfg.dataset_repo_id, (str, ListConfig)): + raise ValueError( + "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of " + "strings to load multiple datasets." ) + # A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora). + if cfg.env.name != "dora": + if isinstance(cfg.dataset_repo_id, str): + dataset_repo_ids = [cfg.dataset_repo_id] # single dataset + else: + dataset_repo_ids = cfg.dataset_repo_id # multiple datasets + + for dataset_repo_id in dataset_repo_ids: + if cfg.env.name not in dataset_repo_id: + logging.warning( + f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your " + f"environment ({cfg.env.name=})." + ) + resolve_delta_timestamps(cfg) if cfg.image_transform.enable: - transform = v2.Compose([v2.ColorJitter(brightness=cfg.image_transform.colorjitter_factor, contrast=cfg.image_transform.colorjitter_factor), - v2.RandomAdjustSharpness(cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p), v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p), - v2.ToDtype(torch.float32, scale=True), - ]) + transform = v2.Compose( + [ + v2.ColorJitter( + brightness=cfg.image_transform.colorjitter_factor, + contrast=cfg.image_transform.colorjitter_factor, + ), + v2.RandomAdjustSharpness( + cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p + ), + v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p), + v2.ToDtype(torch.float32, scale=True), + ] + ) else: transform = None - dataset = LeRobotDataset( - cfg.dataset_repo_id, - split=split, - delta_timestamps=cfg.training.get("delta_timestamps"), - transform=transform - ) + if isinstance(cfg.dataset_repo_id, str): + dataset = LeRobotDataset( + cfg.dataset_repo_id, + split=split, + delta_timestamps=cfg.training.get("delta_timestamps"), + transform=transform, + ) + else: + dataset = MultiLeRobotDataset( + cfg.dataset_repo_id, + split=split, + delta_timestamps=cfg.training.get("delta_timestamps"), + transform=transform, + ) if cfg.get("override_dataset_stats"): for key, stats_dict in cfg.override_dataset_stats.items(): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d51e1fba..929b83e3 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -13,12 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from pathlib import Path +from typing import Callable import datasets import torch +import torch.utils +from lerobot.common.datasets.compute_stats import aggregate_stats from lerobot.common.datasets.utils import ( calculate_episode_data_index, load_episode_data_index, @@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset): version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: callable = None, + transform: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -149,7 +153,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.transform is not None: for cam in self.camera_keys: - item[cam]=self.transform(item[cam]) + item[cam] = self.transform(item[cam]) return item @@ -172,7 +176,7 @@ class LeRobotDataset(torch.utils.data.Dataset): @classmethod def from_preloaded( cls, - repo_id: str, + repo_id: str = "from_preloaded", version: str | None = CODEBASE_VERSION, root: Path | None = None, split: str = "train", @@ -184,7 +188,15 @@ class LeRobotDataset(torch.utils.data.Dataset): stats=None, info=None, videos_dir=None, - ): + ) -> "LeRobotDataset": + """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem. + + It is especially useful when converting raw data into LeRobotDataset before saving the dataset + on the filesystem or uploading to the hub. + + Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially + meaningless depending on the downstream usage of the return dataset. + """ # create an empty object of type LeRobotDataset obj = cls.__new__(cls) obj.repo_id = repo_id @@ -196,6 +208,193 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.hf_dataset = hf_dataset obj.episode_data_index = episode_data_index obj.stats = stats - obj.info = info + obj.info = info if info is not None else {} obj.videos_dir = videos_dir return obj + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + version: str | None = CODEBASE_VERSION, + root: Path | None = DATA_DIR, + split: str = "train", + transform: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + ): + super().__init__() + self.repo_ids = repo_ids + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. + self._datasets = [ + LeRobotDataset( + repo_id, + version=version, + root=root, + split=split, + delta_timestamps=delta_timestamps, + transform=transform, + ) + for repo_id in repo_ids + ] + # Check that some properties are consistent across datasets. Note: We may relax some of these + # consistency requirements in future iterations of this class. + for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): + if dataset.info != self._datasets[0].info: + raise ValueError( + f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is " + "not yet supported." + ) + # Disable any data keys that are not common across all of the datasets. Note: we may relax this + # restriction in future iterations of this class. For now, this is necessary at least for being able + # to use PyTorch's default DataLoader collate function. + self.disabled_data_keys = set() + intersection_data_keys = set(self._datasets[0].hf_dataset.features) + for dataset in self._datasets: + intersection_data_keys.intersection_update(dataset.hf_dataset.features) + if len(intersection_data_keys) == 0: + raise RuntimeError( + "Multiple datasets were provided but they had no keys common to all of them. The " + "multi-dataset functionality currently only keeps common keys." + ) + for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): + extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys) + logging.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_data_keys.update(extra_keys) + + self.version = version + self.root = root + self.split = split + self.transform = transform + self.delta_timestamps = delta_timestamps + self.stats = aggregate_stats(self._datasets) + + @property + def repo_id_to_index(self): + """Return a mapping from dataset repo_id to a dataset index automatically created by this class. + + This index is incorporated as a data key in the dictionary returned by `__getitem__`. + """ + return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} + + @property + def repo_index_to_id(self): + """Return the inverse mapping if repo_id_to_index.""" + return {v: k for k, v in self.repo_id_to_index} + + @property + def fps(self) -> int: + """Frames per second used during data collection. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].info["fps"] + + @property + def video(self) -> bool: + """Returns True if this dataset loads video frames from mp4 files. + + Returns False if it only loads images from png files. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].info.get("video", False) + + @property + def features(self) -> datasets.Features: + features = {} + for dataset in self._datasets: + features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys}) + return features + + @property + def camera_keys(self) -> list[str]: + """Keys to access image and video stream from cameras.""" + keys = [] + for key, feats in self.features.items(): + if isinstance(feats, (datasets.Image, VideoFrame)): + keys.append(key) + return keys + + @property + def video_frame_keys(self) -> list[str]: + """Keys to access video frames that requires to be decoded into images. + + Note: It is empty if the dataset contains images only, + or equal to `self.cameras` if the dataset contains videos only, + or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. + """ + video_frame_keys = [] + for key, feats in self.features.items(): + if isinstance(feats, VideoFrame): + video_frame_keys.append(key) + return video_frame_keys + + @property + def num_samples(self) -> int: + """Number of samples/frames.""" + return sum(d.num_samples for d in self._datasets) + + @property + def num_episodes(self) -> int: + """Number of episodes.""" + return sum(d.num_episodes for d in self._datasets) + + @property + def tolerance_s(self) -> float: + """Tolerance in seconds used to discard loaded frames when their timestamps + are not close enough from the requested frames. It is only used when `delta_timestamps` + is provided or when loading video frames from mp4 files. + """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_samples: + start_idx += dataset.num_samples + dataset_idx += 1 + continue + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_data_keys: + if data_key in item: + del item[data_key] + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Version: '{self.version}',\n" + f" Split: '{self.split}',\n" + f" Number of Samples: {self.num_samples},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.transform},\n" + f")" + ) diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py new file mode 100644 index 00000000..2f6c15c1 --- /dev/null +++ b/lerobot/common/datasets/sampler.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterator, Union + +import torch + + +class EpisodeAwareSampler: + def __init__( + self, + episode_data_index: dict, + episode_indices_to_use: Union[list, None] = None, + drop_n_first_frames: int = 0, + drop_n_last_frames: int = 0, + shuffle: bool = False, + ): + """Sampler that optionally incorporates episode boundary information. + + Args: + episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. + episode_indices_to_use: List of episode indices to use. If None, all episodes are used. + Assumes that episodes are indexed from 0 to N-1. + drop_n_first_frames: Number of frames to drop from the start of each episode. + drop_n_last_frames: Number of frames to drop from the end of each episode. + shuffle: Whether to shuffle the indices. + """ + indices = [] + for episode_idx, (start_index, end_index) in enumerate( + zip(episode_data_index["from"], episode_data_index["to"], strict=True) + ): + if episode_indices_to_use is None or episode_idx in episode_indices_to_use: + indices.extend( + range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) + ) + + self.indices = indices + self.shuffle = shuffle + + def __iter__(self) -> Iterator[int]: + if self.shuffle: + for i in torch.randperm(len(self.indices)): + yield self.indices[i] + else: + for i in self.indices: + yield i + + def __len__(self) -> int: + return len(self.indices) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 86fef8d4..cb2fee95 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"): return outdict -def hf_transform_to_torch(items_dict): +def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to a channel last representation (h w c) of uint8 type, to a torch image representation @@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict): elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item: # video frame will be processed downstream pass + elif first_item is None: + pass else: items_dict[key] = [torch.tensor(x) for x in items_dict[key]] return items_dict @@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: - """ - Reset the `episode_index` of the provided HuggingFace Dataset. + """Reset the `episode_index` of the provided HuggingFace Dataset. `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0. @@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: return example hf_dataset = hf_dataset.map(modify_ep_idx_func) + return hf_dataset diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 95374f4d..a4b0b7d2 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -25,6 +25,13 @@ class ACTConfig: The parameters you will most likely need to change are the ones which depend on the environment / sensors. Those are: `input_shapes` and 'output_shapes`. + Notes on the inputs and outputs: + - At least one key starting with "observation.image is required as an input. + - If there are multiple keys beginning with "observation.images." they are treated as multiple camera + views. Right now we only support all images having the same shape. + - May optionally work without an "observation.state" key for the proprioceptive robot state. + - "action" is required as an output key. + Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). @@ -33,15 +40,15 @@ class ACTConfig: This should be no greater than the chunk size. For example, if the chunk size size 100, you may set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the environment, and throws the other 50 out. - input_shapes: A dictionary defining the shapes of the input data for the policy. - The key represents the input data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "observation.images.top" refers to an input from the - "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesn't include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. - The key represents the output data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), and the value specifies the normalization mode to apply. The two available modes are "mean_std" which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 72ebdd7a..bef59bec 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -198,27 +198,31 @@ class ACT(nn.Module): def __init__(self, config: ACTConfig): super().__init__() self.config = config - # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. + # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + self.use_input_state = "observation.state" in config.input_shapes if self.config.use_vae: self.vae_encoder = ACTEncoder(config) self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) # Projection layer for joint-space configuration to hidden dimension. - self.vae_encoder_robot_state_input_proj = nn.Linear( - config.input_shapes["observation.state"][0], config.dim_model - ) + if self.use_input_state: + self.vae_encoder_robot_state_input_proj = nn.Linear( + config.input_shapes["observation.state"][0], config.dim_model + ) # Projection layer for action (joint-space target) to hidden dimension. self.vae_encoder_action_input_proj = nn.Linear( - config.input_shapes["observation.state"][0], config.dim_model + config.output_shapes["action"][0], config.dim_model ) - self.latent_dim = config.latent_dim # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2) - # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch + self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) + # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch # dimension. + num_input_token_encoder = 1 + config.chunk_size + if self.use_input_state: + num_input_token_encoder += 1 self.register_buffer( "vae_encoder_pos_enc", - create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0), + create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), ) # Backbone for image feature extraction. @@ -238,15 +242,17 @@ class ACT(nn.Module): # Transformer encoder input projections. The tokens will be structured like # [latent, robot_state, image_feature_map_pixels]. - self.encoder_robot_state_input_proj = nn.Linear( - config.input_shapes["observation.state"][0], config.dim_model - ) - self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model) + if self.use_input_state: + self.encoder_robot_state_input_proj = nn.Linear( + config.input_shapes["observation.state"][0], config.dim_model + ) + self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) self.encoder_img_feat_input_proj = nn.Conv2d( backbone_model.fc.in_features, config.dim_model, kernel_size=1 ) # Transformer encoder positional embeddings. - self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model) + num_input_token_decoder = 2 if self.use_input_state else 1 + self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model) self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) # Transformer decoder. @@ -285,7 +291,7 @@ class ACT(nn.Module): "action" in batch ), "actions must be provided when using the variational objective in training mode." - batch_size = batch["observation.state"].shape[0] + batch_size = batch["observation.images"].shape[0] # Prepare the latent for input to the transformer encoder. if self.config.use_vae and "action" in batch: @@ -293,11 +299,16 @@ class ACT(nn.Module): cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) - robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze( - 1 - ) # (B, 1, D) + if self.use_input_state: + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) - vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) + + if self.use_input_state: + vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) + else: + vae_encoder_input = [cls_embed, action_embed] + vae_encoder_input = torch.cat(vae_encoder_input, axis=1) # Prepare fixed positional embedding. # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. @@ -308,16 +319,17 @@ class ACT(nn.Module): vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) )[0] # select the class token, with shape (B, D) latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) - mu = latent_pdf_params[:, : self.latent_dim] + mu = latent_pdf_params[:, : self.config.latent_dim] # This is 2log(sigma). Done this way to match the original implementation. - log_sigma_x2 = latent_pdf_params[:, self.latent_dim :] + log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] # Sample the latent with the reparameterization trick. latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) else: # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None - latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to( + # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer + latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( batch["observation.state"].device ) @@ -326,8 +338,10 @@ class ACT(nn.Module): all_cam_features = [] all_cam_pos_embeds = [] images = batch["observation.images"] + for cam_index in range(images.shape[-4]): cam_features = self.backbone(images[:, cam_index])["feature_map"] + # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) all_cam_features.append(cam_features) @@ -337,13 +351,15 @@ class ACT(nn.Module): cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1) # Get positional embeddings for robot state and latent. - robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) + if self.use_input_state: + robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C) latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C) # Stack encoder input and positional embeddings moving to (S, B, C). + encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed] encoder_in = torch.cat( [ - torch.stack([latent_embed, robot_state_embed], axis=0), + torch.stack(encoder_in_feats, axis=0), einops.rearrange(encoder_in, "b c h w -> (h w) b c"), ] ) @@ -357,6 +373,7 @@ class ACT(nn.Module): # Forward pass through the transformer modules. encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) + # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer decoder_in = torch.zeros( (self.config.chunk_size, batch_size, self.config.dim_model), dtype=pos_embed.dtype, diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 81ff5de7..59ed1656 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -26,21 +26,26 @@ class DiffusionConfig: The parameters you will most likely need to change are the ones which depend on the environment / sensors. Those are: `input_shapes` and `output_shapes`. + Notes on the inputs and outputs: + - "observation.state" is required as an input key. + - A key starting with "observation.image is required as an input. + - "action" is required as an output key. + Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. See `DiffusionPolicy.select_action` for more details. - input_shapes: A dictionary defining the shapes of the input data for the policy. - The key represents the input data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "observation.image" refers to an input from - a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesnt include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. - The key represents the output data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), and the value specifies the normalization mode to apply. The two available modes are "mean_std" which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index cf76fb08..49485c39 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -31,6 +31,15 @@ class TDMPCConfig: n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google action repeats in Q-learning or ask your favorite chatbot) horizon: Horizon for model predictive control. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), and the value specifies the normalization mode to apply. The two available modes are "mean_std" which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 696999ad..c429efbd 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -120,13 +120,13 @@ def init_logging(): logging.getLogger().addHandler(console_handler) -def format_big_number(num): +def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] divisor = 1000.0 for suffix in suffixes: if abs(num) < divisor: - return f"{num:.0f}{suffix}" + return f"{num:.{precision}f}{suffix}" num /= divisor return num diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 6ceca699..ebabf5e0 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -23,6 +23,10 @@ use_amp: false # `seed` is used for training (eg: model initialization, dataset shuffling) # AND for the evaluation environments. seed: ??? +# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data +# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the +# "dataset_index" into the returned item. The index mapping is made according to the order in which the +# datsets are provided. dataset_repo_id: lerobot/pusht training: @@ -55,10 +59,10 @@ wandb: notes: "" image_transform: - enable: false + enable: false colorjittor_factor: 0.5 colorjittor_p: 0.5 sharpness_factor: 2 sharpness_p: 0.5 blur_factor: 0.5 - blur_p: 0.5 \ No newline at end of file + blur_p: 0.5 diff --git a/lerobot/configs/env/dora_aloha_real.yaml b/lerobot/configs/env/dora_aloha_real.yaml new file mode 100644 index 00000000..088781d4 --- /dev/null +++ b/lerobot/configs/env/dora_aloha_real.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +fps: 30 + +env: + name: dora + task: DoraAloha-v0 + state_dim: 14 + action_dim: 14 + fps: ${fps} + episode_length: 400 + gym: + fps: ${fps} diff --git a/lerobot/configs/policy/act_real.yaml b/lerobot/configs/policy/act_real.yaml new file mode 100644 index 00000000..b4942615 --- /dev/null +++ b/lerobot/configs/policy/act_real.yaml @@ -0,0 +1,115 @@ +# @package _global_ + +# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets. +# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images, +# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used +# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation. +# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot). +# Look at its README for more information on how to evaluate a checkpoint in the real-world. +# +# Example of usage for training: +# ```bash +# python lerobot/scripts/train.py \ +# policy=act_real \ +# env=dora_aloha_real +# ``` + +seed: 1000 +dataset_repo_id: lerobot/aloha_static_vinh_cup + +override_dataset_stats: + observation.images.cam_right_wrist: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + observation.images.cam_left_wrist: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + observation.images.cam_high: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + observation.images.cam_low: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + +training: + offline_steps: 80000 + online_steps: 0 + eval_freq: -1 + save_freq: 10000 + log_freq: 100 + save_checkpoint: true + + batch_size: 8 + lr: 1e-5 + lr_backbone: 1e-5 + weight_decay: 1e-4 + grad_clip_norm: 10 + online_steps_between_rollouts: 1 + + delta_timestamps: + action: "[i / ${fps} for i in range(${policy.chunk_size})]" + +eval: + n_episodes: 50 + batch_size: 50 + +# See `configuration_act.py` for more details. +policy: + name: act + + # Input / output structure. + n_obs_steps: 1 + chunk_size: 100 # chunk_size + n_action_steps: 100 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.images.cam_right_wrist: [3, 480, 640] + observation.images.cam_left_wrist: [3, 480, 640] + observation.images.cam_high: [3, 480, 640] + observation.images.cam_low: [3, 480, 640] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.images.cam_right_wrist: mean_std + observation.images.cam_left_wrist: mean_std + observation.images.cam_high: mean_std + observation.images.cam_low: mean_std + observation.state: mean_std + output_normalization_modes: + action: mean_std + + # Architecture. + # Vision backbone. + vision_backbone: resnet18 + pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1 + replace_final_stride_with_dilation: false + # Transformer layers. + pre_norm: false + dim_model: 512 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + n_encoder_layers: 4 + # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code + # that means only the first layer is used. Here we match the original implementation by setting this to 1. + # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. + n_decoder_layers: 1 + # VAE. + use_vae: true + latent_dim: 32 + n_vae_encoder_layers: 4 + + # Inference. + temporal_ensemble_momentum: null + + # Training and loss computation. + dropout: 0.1 + kl_weight: 10.0 diff --git a/lerobot/configs/policy/act_real_no_state.yaml b/lerobot/configs/policy/act_real_no_state.yaml new file mode 100644 index 00000000..a8b1c9b6 --- /dev/null +++ b/lerobot/configs/policy/act_real_no_state.yaml @@ -0,0 +1,111 @@ +# @package _global_ + +# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras) +# Compared to `act_real.yaml`, it is camera only and does not use the state as input which is vector of robot joint positions. +# We validated experimentaly that not using state reaches better success rate. Our hypothesis is that `act_real.yaml` might +# overfits to the state, because the images are more complex to learn from since they are moving. +# +# Example of usage for training: +# ```bash +# python lerobot/scripts/train.py \ +# policy=act_real_no_state \ +# env=dora_aloha_real +# ``` + +seed: 1000 +dataset_repo_id: lerobot/aloha_static_vinh_cup + +override_dataset_stats: + observation.images.cam_right_wrist: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + observation.images.cam_left_wrist: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + observation.images.cam_high: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + observation.images.cam_low: + # stats from imagenet, since we use a pretrained vision model + mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) + std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) + +training: + offline_steps: 80000 + online_steps: 0 + eval_freq: -1 + save_freq: 10000 + log_freq: 100 + save_checkpoint: true + + batch_size: 8 + lr: 1e-5 + lr_backbone: 1e-5 + weight_decay: 1e-4 + grad_clip_norm: 10 + online_steps_between_rollouts: 1 + + delta_timestamps: + action: "[i / ${fps} for i in range(${policy.chunk_size})]" + +eval: + n_episodes: 50 + batch_size: 50 + +# See `configuration_act.py` for more details. +policy: + name: act + + # Input / output structure. + n_obs_steps: 1 + chunk_size: 100 # chunk_size + n_action_steps: 100 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.images.cam_right_wrist: [3, 480, 640] + observation.images.cam_left_wrist: [3, 480, 640] + observation.images.cam_high: [3, 480, 640] + observation.images.cam_low: [3, 480, 640] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.images.cam_right_wrist: mean_std + observation.images.cam_left_wrist: mean_std + observation.images.cam_high: mean_std + observation.images.cam_low: mean_std + output_normalization_modes: + action: mean_std + + # Architecture. + # Vision backbone. + vision_backbone: resnet18 + pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1 + replace_final_stride_with_dilation: false + # Transformer layers. + pre_norm: false + dim_model: 512 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + n_encoder_layers: 4 + # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code + # that means only the first layer is used. Here we match the original implementation by setting this to 1. + # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. + n_decoder_layers: 1 + # VAE. + use_vae: true + latent_dim: 32 + n_vae_encoder_layers: 4 + + # Inference. + temporal_ensemble_momentum: null + + # Training and loss computation. + dropout: 0.1 + kl_weight: 10.0 diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 36bd22cc..b04ecf1b 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -44,6 +44,10 @@ training: observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]" + # The original implementation doesn't sample frames for the last 7 steps, + # which avoids excessive padding and leads to improved training results. + drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1 + eval: n_episodes: 50 batch_size: 50 diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index c6eac5e9..52252b57 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -71,9 +71,9 @@ import torch from huggingface_hub import HfApi from safetensors.torch import save_file +from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw -from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats from lerobot.common.datasets.utils import flatten_dict diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index eb33b268..860412bd 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -16,7 +16,6 @@ import logging import time from contextlib import nullcontext -from copy import deepcopy from pathlib import Path from pprint import pformat @@ -28,6 +27,8 @@ from termcolor import colored from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps +from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset +from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir @@ -280,6 +281,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("make_dataset") offline_dataset = make_dataset(cfg) + if isinstance(offline_dataset, MultiLeRobotDataset): + logging.info( + "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " + f"{pformat(offline_dataset.repo_id_to_index , indent=2)}" + ) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, @@ -330,7 +336,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No max_episodes_rendered=4, start_seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) + log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True) if cfg.wandb.enable: logger.log_video(eval_info["video_paths"][0], step, mode="eval") logging.info("Resume training") @@ -351,18 +357,28 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info("Resume training") # create dataloader for offline training + if cfg.training.get("drop_n_last_frames"): + shuffle = False + sampler = EpisodeAwareSampler( + offline_dataset.episode_data_index, + drop_n_last_frames=cfg.training.drop_n_last_frames, + shuffle=True, + ) + else: + shuffle = True + sampler = None dataloader = torch.utils.data.DataLoader( offline_dataset, num_workers=cfg.training.num_workers, batch_size=cfg.training.batch_size, - shuffle=True, + shuffle=shuffle, + sampler=sampler, pin_memory=device.type != "cpu", drop_last=False, ) dl_iter = cycle(dataloader) policy.train() - is_offline = True for _ in range(step, cfg.training.offline_steps): if step == 0: logging.info("Start offline training on a fixed dataset") @@ -382,7 +398,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No ) if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline) + log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True) # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, # so we pass in step + 1. @@ -390,41 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No step += 1 - logging.info("End of offline training") - - if cfg.training.online_steps == 0: - if cfg.training.eval_freq > 0: - eval_env.close() - return - - # create an env dedicated to online episodes collection from policy rollout - online_training_env = make_env(cfg, n_envs=1) - - # create an empty online dataset similar to offline dataset - online_dataset = deepcopy(offline_dataset) - online_dataset.hf_dataset = {} - online_dataset.episode_data_index = {} - - # create dataloader for online training - concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) - weights = [1.0] * len(concat_dataset) - sampler = torch.utils.data.WeightedRandomSampler( - weights, num_samples=len(concat_dataset), replacement=True - ) - dataloader = torch.utils.data.DataLoader( - concat_dataset, - num_workers=4, - batch_size=cfg.training.batch_size, - sampler=sampler, - pin_memory=device.type != "cpu", - drop_last=False, - ) - - logging.info("End of online training") - - if cfg.training.eval_freq > 0: - eval_env.close() - online_training_env.close() + eval_env.close() + logging.info("End of training") @hydra.main(version_base="1.2", config_name="default", config_path="../configs") diff --git a/poetry.lock b/poetry.lock index 3a04e3d1..e9d6e848 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -444,63 +444,63 @@ files = [ [[package]] name = "coverage" -version = "7.5.1" +version = "7.5.3" description = "Code coverage measurement for Python" optional = true python-versions = ">=3.8" files = [ - {file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"}, - {file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"}, - {file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"}, - {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"}, - {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"}, - {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"}, - {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"}, - {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"}, - {file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"}, - {file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"}, - {file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"}, - {file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"}, - {file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"}, - {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"}, - {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"}, - {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"}, - {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"}, - {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"}, - {file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"}, - {file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"}, - {file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"}, - {file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"}, - {file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"}, - {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"}, - {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"}, - {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"}, - {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"}, - {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"}, - {file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"}, - {file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"}, - {file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"}, - {file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"}, - {file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"}, - {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"}, - {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"}, - {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"}, - {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"}, - {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"}, - {file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"}, - {file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"}, - {file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"}, - {file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"}, - {file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"}, - {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"}, - {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"}, - {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"}, - {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"}, - {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"}, - {file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"}, - {file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"}, - {file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"}, - {file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"}, + {file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"}, + {file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"}, + {file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"}, + {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"}, + {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"}, + {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"}, + {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"}, + {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"}, + {file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"}, + {file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"}, + {file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"}, + {file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"}, + {file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"}, + {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"}, + {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"}, + {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"}, + {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"}, + {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"}, + {file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"}, + {file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"}, + {file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"}, + {file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"}, + {file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"}, + {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"}, + {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"}, + {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"}, + {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"}, + {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"}, + {file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"}, + {file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"}, + {file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"}, + {file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"}, + {file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"}, + {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"}, + {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"}, + {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"}, + {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"}, + {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"}, + {file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"}, + {file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"}, + {file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"}, + {file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"}, + {file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"}, + {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"}, + {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"}, + {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"}, + {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"}, + {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"}, + {file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"}, + {file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"}, + {file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"}, + {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"}, ] [package.dependencies] @@ -785,6 +785,26 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "dora-rs" +version = "0.3.4" +description = "`dora` goal is to be a low latency, composable, and distributed data flow." +optional = true +python-versions = "*" +files = [ + {file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"}, + {file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"}, + {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"}, + {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"}, + {file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"}, + {file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"}, + {file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"}, + {file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"}, +] + +[package.dependencies] +pyarrow = "*" + [[package]] name = "einops" version = "0.8.0" @@ -1066,6 +1086,27 @@ mujoco = ">=2.3.7,<3.0.0" dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"] test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"] +[[package]] +name = "gym-dora" +version = "0.1.0" +description = "" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +dora-rs = ">=0.3.4" +gymnasium = ">=0.29.1" +pyarrow = ">=12.0.0" + +[package.source] +type = "git" +url = "https://github.com/dora-rs/dora-lerobot.git" +reference = "HEAD" +resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79" +subdirectory = "gym_dora" + [[package]] name = "gym-pusht" version = "0.1.4" @@ -1269,13 +1310,13 @@ files = [ [[package]] name = "huggingface-hub" -version = "0.23.1" +version = "0.23.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"}, - {file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"}, + {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"}, + {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"}, ] [package.dependencies] @@ -2061,18 +2102,15 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "nodeenv" -version = "1.8.0" +version = "1.9.0" description = "Node.js virtual environment builder" optional = true -python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ - {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, - {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, + {file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"}, + {file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"}, ] -[package.dependencies] -setuptools = "*" - [[package]] name = "numba" version = "0.59.1" @@ -2406,6 +2444,7 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -2426,6 +2465,7 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -3188,13 +3228,13 @@ files = [ [[package]] name = "requests" -version = "2.32.2" +version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" files = [ - {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, - {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -3210,16 +3250,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rerun-sdk" -version = "0.16.0" +version = "0.16.1" description = "The Rerun Logging SDK" optional = false python-versions = "<3.13,>=3.8" files = [ - {file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"}, - {file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"}, + {file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"}, ] [package.dependencies] @@ -3696,17 +3736,17 @@ files = [ [[package]] name = "sympy" -version = "1.12" +version = "1.12.1" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, - {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, + {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, + {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, ] [package.dependencies] -mpmath = ">=0.19" +mpmath = ">=1.1.0,<1.4.0" [[package]] name = "tbb" @@ -4220,13 +4260,13 @@ multidict = ">=4.0" [[package]] name = "zarr" -version = "2.18.1" +version = "2.18.2" description = "An implementation of chunked, compressed, N-dimensional arrays for Python" optional = false python-versions = ">=3.9" files = [ - {file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"}, - {file = "zarr-2.18.1.tar.gz", hash = "sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"}, + {file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"}, + {file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"}, ] [package.dependencies] @@ -4241,13 +4281,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"] [[package]] name = "zipp" -version = "3.18.2" +version = "3.19.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"}, - {file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"}, + {file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"}, + {file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"}, ] [package.extras] @@ -4257,6 +4297,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more [extras] aloha = ["gym-aloha"] dev = ["debugpy", "pre-commit"] +dora = ["gym-dora"] pusht = ["gym-pusht"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] @@ -4265,4 +4306,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "1ad6ef0f88f0056ab639e60e033e586f7460a9c5fc3676a477bbd47923f41cb6" +content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771" diff --git a/pyproject.toml b/pyproject.toml index 1dd8b1d6..0c305218 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ h5py = ">=3.10.0" huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"} gymnasium = ">=0.29.1" cmake = ">=3.29.0.1" +gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true } gym-pusht = { version = ">=0.1.3", optional = true} gym-xarm = { version = ">=0.1.1", optional = true} gym-aloha = { version = ">=0.1.1", optional = true} @@ -62,6 +63,7 @@ deepdiff = ">=7.0.1" [tool.poetry.extras] +dora = ["gym-dora"] pusht = ["gym-pusht"] xarm = ["gym-xarm"] aloha = ["gym-aloha"] diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors new file mode 100644 index 00000000..2373f1ee --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4 +size 5104 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors new file mode 100644 index 00000000..de40a20e --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901 +size 31688 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors new file mode 100644 index 00000000..8602cc56 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a +size 68 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors new file mode 100644 index 00000000..a6612b7f --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99 +size 34928 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors new file mode 100644 index 00000000..9f0ba883 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716 +size 5104 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors new file mode 100644 index 00000000..2b01b94c --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b +size 30808 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/output_dict.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/output_dict.safetensors new file mode 100644 index 00000000..c2417bf8 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386 +size 68 diff --git a/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors new file mode 100644 index 00000000..335d2a55 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/dora_aloha_real_act_real_no_state/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6 +size 33608 diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensors.py similarity index 91% rename from tests/scripts/save_policy_to_safetensor.py rename to tests/scripts/save_policy_to_safetensors.py index 89f33374..961b7cef 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides): # HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension dataset.delta_timestamps = None batch = next(iter(dataloader)) - obs = { - k: batch[k] - for k in batch - if k in ["observation.image", "observation.images.top", "observation.state"] - } + obs = {} + for k in batch: + if k.startswith("observation"): + obs[k] = batch[k] + + if "n_action_steps" in cfg.policy: + actions_queue = cfg.policy.n_action_steps + else: + actions_queue = cfg.policy.n_action_repeats - actions_queue = ( - cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats - ) actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} return output_dict, grad_stats, param_stats, actions @@ -114,6 +115,8 @@ if __name__ == "__main__": ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ), ("aloha", "act", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), ] for env, policy, extra_overrides in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index afea16a5..da0ae755 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -16,6 +16,7 @@ import json import logging from copy import deepcopy +from itertools import chain from pathlib import Path import einops @@ -25,26 +26,34 @@ from datasets import Dataset from safetensors.torch import load_file import lerobot -from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.lerobot_dataset import ( - LeRobotDataset, -) -from lerobot.common.datasets.push_dataset_to_hub.compute_stats import ( +from lerobot.common.datasets.compute_stats import ( + aggregate_stats, compute_stats, get_stats_einops_patterns, ) +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset from lerobot.common.datasets.utils import ( flatten_dict, hf_transform_to_torch, load_previous_and_future_frames, unflatten_dict, ) -from lerobot.common.utils.utils import init_hydra_config +from lerobot.common.utils.utils import init_hydra_config, seeded_context from tests.utils import DEFAULT_CONFIG_PATH, DEVICE -@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets) +@pytest.mark.parametrize( + "env_name, repo_id, policy_name", + lerobot.env_dataset_policy_triplets + + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")], +) def test_factory(env_name, repo_id, policy_name): + """ + Tests that: + - we can create a dataset with the factory. + - for a commonly used set of data keys, the data dimensions are correct. + """ cfg = init_hydra_config( DEFAULT_CONFIG_PATH, overrides=[ @@ -105,6 +114,39 @@ def test_factory(env_name, repo_id, policy_name): assert key in item, f"{key}" +# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds. +def test_multilerobotdataset_frames(): + """Check that all dataset frames are incorporated.""" + # Note: use the image variants of the dataset to make the test approx 3x faster. + # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining + # logic that wouldn't be caught with two repo IDs. + repo_ids = [ + "lerobot/aloha_sim_insertion_human_image", + "lerobot/aloha_sim_transfer_cube_human_image", + "lerobot/aloha_sim_insertion_scripted_image", + ] + sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] + dataset = MultiLeRobotDataset(repo_ids) + assert len(dataset) == sum(len(d) for d in sub_datasets) + assert dataset.num_samples == sum(d.num_samples for d in sub_datasets) + assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets) + + # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and + # check they match. + expected_dataset_indices = [] + for i, sub_dataset in enumerate(sub_datasets): + expected_dataset_indices.extend([i] * len(sub_dataset)) + + for expected_dataset_index, sub_dataset_item, dataset_item in zip( + expected_dataset_indices, chain(*sub_datasets), dataset, strict=True + ): + dataset_index = dataset_item.pop("dataset_index") + assert dataset_index == expected_dataset_index + assert sub_dataset_item.keys() == dataset_item.keys() + for k in sub_dataset_item: + assert torch.equal(sub_dataset_item[k], dataset_item[k]) + + def test_compute_stats_on_xarm(): """Check that the statistics are computed correctly according to the stats_patterns property. @@ -315,3 +357,31 @@ def test_backward_compatibility(repo_id): # i = dataset.episode_data_index["to"][-1].item() # load_and_compare(i - 2) # load_and_compare(i - 1) + + +def test_aggregate_stats(): + """Makes 3 basic datasets and checks that aggregate stats are computed correctly.""" + with seeded_context(0): + data_a = torch.rand(30, dtype=torch.float32) + data_b = torch.rand(20, dtype=torch.float32) + data_c = torch.rand(20, dtype=torch.float32) + + hf_dataset_1 = Dataset.from_dict( + {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)} + ) + hf_dataset_1.set_transform(hf_transform_to_torch) + hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}) + hf_dataset_2.set_transform(hf_transform_to_torch) + hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}) + hf_dataset_3.set_transform(hf_transform_to_torch) + dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1) + dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0) + dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2) + dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0) + dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3) + dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0) + stats = aggregate_stats([dataset_1, dataset_2, dataset_3]) + for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True): + for agg_fn in ["mean", "min", "max"]: + assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)) + assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0)) diff --git a/tests/test_policies.py b/tests/test_policies.py index bb0c7b80..c099bef0 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config -from tests.scripts.save_policy_to_safetensor import get_policy_stats +from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str): ), # Note: these parameters also need custom logic in the test function for overriding the Hydra config. ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]), + ("dora_aloha_real", "act_real", []), + ("dora_aloha_real", "act_real_no_state", []), ], ) @require_env @@ -84,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides): - Updating the policy. - Using the policy to select actions at inference time. - Test the action can be applied to the policy + + Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive, + and for now we add tests as we see fit. """ cfg = init_hydra_config( DEFAULT_CONFIG_PATH, @@ -135,7 +140,7 @@ def test_policy(env_name, policy_name, extra_overrides): dataloader = torch.utils.data.DataLoader( dataset, - num_workers=4, + num_workers=0, batch_size=2, shuffle=True, pin_memory=DEVICE != "cpu", @@ -291,6 +296,8 @@ def test_normalize(insert_temporal_dim): ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ), ("aloha", "act", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), ], ) # As artifacts have been generated on an x86_64 kernel, this test won't diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 00000000..635e7f11 --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datasets import Dataset + +from lerobot.common.datasets.sampler import EpisodeAwareSampler +from lerobot.common.datasets.utils import ( + calculate_episode_data_index, + hf_transform_to_torch, +) + + +def test_drop_n_first_frames(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) + assert sampler.indices == [1, 4, 5] + assert len(sampler) == 3 + assert list(sampler) == [1, 4, 5] + + +def test_drop_n_last_frames(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) + assert sampler.indices == [0, 3, 4] + assert len(sampler) == 3 + assert list(sampler) == [0, 3, 4] + + +def test_episode_indices_to_use(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) + assert sampler.indices == [0, 1, 3, 4, 5] + assert len(sampler) == 5 + assert list(sampler) == [0, 1, 3, 4, 5] + + +def test_shuffle(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + sampler = EpisodeAwareSampler(episode_data_index, shuffle=False) + assert sampler.indices == [0, 1, 2, 3, 4, 5] + assert len(sampler) == 6 + assert list(sampler) == [0, 1, 2, 3, 4, 5] + sampler = EpisodeAwareSampler(episode_data_index, shuffle=True) + assert sampler.indices == [0, 1, 2, 3, 4, 5] + assert len(sampler) == 6 + assert set(sampler) == {0, 1, 2, 3, 4, 5}