diff --git a/Makefile b/Makefile index 9bac437d..f6517497 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,7 @@ test-end-to-end: ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train + ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval ${MAKE} DEVICE=$(DEVICE) test-default-ete-eval ${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial @@ -113,7 +114,6 @@ test-diffusion-ete-eval: env.episode_length=8 \ device=$(DEVICE) \ -# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated. test-tdmpc-ete-train: python lerobot/scripts/train.py \ policy=tdmpc \ @@ -133,6 +133,28 @@ test-tdmpc-ete-train: training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/tdmpc/ +test-tdmpc-ete-train-with-online: + python lerobot/scripts/train.py \ + env=pusht \ + env.gym.obs_type=environment_state_agent_pos \ + policy=tdmpc_pusht_keypoints \ + eval.n_episodes=1 \ + eval.batch_size=1 \ + env.episode_length=10 \ + device=$(DEVICE) \ + training.offline_steps=2 \ + training.online_steps=20 \ + training.save_checkpoint=false \ + training.save_freq=10 \ + training.batch_size=2 \ + training.online_rollout_n_episodes=2 \ + training.online_rollout_batch_size=2 \ + training.online_steps_between_rollouts=10 \ + training.online_buffer_capacity=15 \ + eval.use_async_envs=true \ + hydra.run.dir=tests/outputs/tdmpc_online/ + + test-tdmpc-ete-eval: python lerobot/scripts/eval.py \ -p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ diff --git a/lerobot/common/datasets/online_buffer.py b/lerobot/common/datasets/online_buffer.py new file mode 100644 index 00000000..6b093cda --- /dev/null +++ b/lerobot/common/datasets/online_buffer.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An online buffer for the online training loop in train.py + +Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should +consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much +faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it +supports in-place slicing and mutation which is very handy for a dynamic buffer. +""" + +import os +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +def _make_memmap_safe(**kwargs) -> np.memmap: + """Make a numpy memmap with checks on available disk space first. 
+
+    Expected kwargs are: "filename", "dtype" (must be np.dtype), "mode" and "shape"
+
+    For information on dtypes:
+    https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
+    """
+    if kwargs["mode"].startswith("w"):
+        required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"])  # bytes
+        stats = os.statvfs(Path(kwargs["filename"]).parent)
+        available_space = stats.f_bavail * stats.f_frsize  # bytes
+        if required_space >= available_space * 0.8:
+            raise RuntimeError(
+                f"You're about to take up {required_space} of {available_space} bytes available."
+            )
+    return np.memmap(**kwargs)
+
+
+class OnlineBuffer(torch.utils.data.Dataset):
+    """FIFO data buffer for the online training loop in train.py.
+
+    Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
+    loop in the same way that a LeRobotDataset would be used.
+
+    The underlying data structure will have data inserted in a circular fashion. Always insert after the
+    last index, and when you reach the end, wrap around to the start.
+
+    The data is stored in a numpy memmap.
+    """
+
+    NEXT_INDEX_KEY = "_next_index"
+    OCCUPANCY_MASK_KEY = "_occupancy_mask"
+    INDEX_KEY = "index"
+    FRAME_INDEX_KEY = "frame_index"
+    EPISODE_INDEX_KEY = "episode_index"
+    TIMESTAMP_KEY = "timestamp"
+    IS_PAD_POSTFIX = "_is_pad"
+
+    def __init__(
+        self,
+        write_dir: str | Path,
+        data_spec: dict[str, Any] | None,
+        buffer_capacity: int | None,
+        fps: float | None = None,
+        delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
+    ):
+        """
+        The online buffer can be created from scratch, or you can load an existing online buffer by passing
+        a `write_dir` associated with an existing buffer.
+
+        Args:
+            write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
+                Note that if the files already exist, they are opened in read-write mode (used for training
+                resumption).
+            data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
+                "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
+                but note that "index", "frame_index" and "episode_index" are already accounted for by this
+                class, so you don't need to include them.
+            buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
+                system's available disk space when choosing this.
+            fps: Same as the fps concept in LeRobotDataset. Here it needs to be provided for the
+                delta_timestamps logic. You can pass None if you are not using delta_timestamps.
+            delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
+                converted to dict[str, np.ndarray] for optimization purposes.
+
+        """
+        self.set_delta_timestamps(delta_timestamps)
+        self._fps = fps
+        # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough to
+        # the requested frames. It is only used when `delta_timestamps` is provided.
+ # minus 1e-4 to account for possible numerical error + self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None + self._buffer_capacity = buffer_capacity + data_spec = self._make_data_spec(data_spec, buffer_capacity) + Path(write_dir).mkdir(parents=True, exist_ok=True) + self._data = {} + for k, v in data_spec.items(): + self._data[k] = _make_memmap_safe( + filename=Path(write_dir) / k, + dtype=v["dtype"] if v is not None else None, + mode="r+" if (Path(write_dir) / k).exists() else "w+", + shape=tuple(v["shape"]) if v is not None else None, + ) + + @property + def delta_timestamps(self) -> dict[str, np.ndarray] | None: + return self._delta_timestamps + + def set_delta_timestamps(self, value: dict[str, list[float]] | None): + """Set delta_timestamps converting the values to numpy arrays. + + The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays + need to be converted into numpy arrays. + """ + if value is not None: + self._delta_timestamps = {k: np.array(v) for k, v in value.items()} + else: + self._delta_timestamps = None + + def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]: + """Makes the data spec for np.memmap.""" + if any(k.startswith("_") for k in data_spec): + raise ValueError( + "data_spec keys should not start with '_'. This prefix is reserved for internal logic." + ) + preset_keys = { + OnlineBuffer.INDEX_KEY, + OnlineBuffer.FRAME_INDEX_KEY, + OnlineBuffer.EPISODE_INDEX_KEY, + OnlineBuffer.TIMESTAMP_KEY, + } + if len(intersection := set(data_spec).intersection(preset_keys)) > 0: + raise ValueError( + f"data_spec should not contain any of {preset_keys} as these are handled internally. " + f"The provided data_spec has {intersection}." + ) + complete_data_spec = { + # _next_index will be a pointer to the next index that we should start filling from when we add + # more data. + OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()}, + # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied + # with real data rather than the dummy initialization. + OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)}, + OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, + OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, + OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, + OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)}, + } + for k, v in data_spec.items(): + complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])} + return complete_data_spec + + def add_data(self, data: dict[str, np.ndarray]): + """Add new data to the buffer, which could potentially mean shifting old data out. + + The new data should contain all the frames (in order) of any number of episodes. The indices should + start from 0 (note to the developer: this can easily be generalized). See the `rollout` and + `eval_policy` functions in `eval.py` for more information on how the data is constructed. + + Shift the incoming data index and episode_index to continue on from the last frame. Note that this + will be done in place! 
+ """ + if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0: + raise ValueError(f"Missing data keys: {missing_keys}") + new_data_length = len(data[self.data_keys[0]]) + if not all(len(data[k]) == new_data_length for k in self.data_keys): + raise ValueError("All data items should have the same length") + + next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY] + + # Sanity check to make sure that the new data indices start from 0. + assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0 + assert data[OnlineBuffer.INDEX_KEY][0].item() == 0 + + # Shift the incoming indices if necessary. + if self.num_samples > 0: + last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] + last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] + data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 + data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 + + # Insert the new data starting from next_index. It may be necessary to wrap around to the start. + n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index)) + for k in self.data_keys: + if n_surplus == 0: + slc = slice(next_index, next_index + new_data_length) + self._data[k][slc] = data[k] + self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True + else: + self._data[k][next_index:] = data[k][:-n_surplus] + self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True + self._data[k][:n_surplus] = data[k][-n_surplus:] + if n_surplus == 0: + self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length + else: + self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus + + @property + def data_keys(self) -> list[str]: + keys = set(self._data) + keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY) + keys.remove(OnlineBuffer.NEXT_INDEX_KEY) + return sorted(keys) + + @property + def fps(self) -> float | None: + return self._fps + + @property + def num_episodes(self) -> int: + return len( + np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) + ) + + @property + def num_samples(self) -> int: + return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]) + + def __len__(self): + return self.num_samples + + def _item_to_tensors(self, item: dict) -> dict: + item_ = {} + for k, v in item.items(): + if isinstance(v, torch.Tensor): + item_[k] = v + elif isinstance(v, np.ndarray): + item_[k] = torch.from_numpy(v) + else: + item_[k] = torch.tensor(v) + return item_ + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self) or idx < -len(self): + raise IndexError + + item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")} + + if self.delta_timestamps is None: + return self._item_to_tensors(item) + + episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY] + current_ts = item[OnlineBuffer.TIMESTAMP_KEY] + episode_data_indices = np.where( + np.bitwise_and( + self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index, + self._data[OnlineBuffer.OCCUPANCY_MASK_KEY], + ) + )[0] + episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices] + + for data_key in self.delta_timestamps: + # Note: The logic in this loop is copied from `load_previous_and_future_frames`. + # Get timestamps used as query to retrieve data of previous/future frames. + query_ts = current_ts + self.delta_timestamps[data_key] + + # Compute distances between each query timestamp and all timestamps of all the frames belonging to + # the episode. 
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
+            argmin_ = np.argmin(dist, axis=1)
+            min_ = dist[np.arange(dist.shape[0]), argmin_]
+
+            is_pad = min_ > self.tolerance_s
+
+            # Check violated query timestamps are all outside the episode range.
+            assert (
+                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
+            ).all(), (
+                f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
+                ") inside the episode range."
+            )
+
+            # Load frames for this data key.
+            item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
+
+            item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
+
+        return self._item_to_tensors(item)
+
+    def get_data_by_key(self, key: str) -> torch.Tensor:
+        """Returns all data for a given data key as a Tensor."""
+        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
+
+
+def compute_sampler_weights(
+    offline_dataset: LeRobotDataset,
+    offline_drop_n_last_frames: int = 0,
+    online_dataset: OnlineBuffer | None = None,
+    online_sampling_ratio: float | None = None,
+    online_drop_n_last_frames: int = 0,
+) -> torch.Tensor:
+    """Compute the sampling weights for the online training dataloader in train.py.
+
+    Args:
+        offline_dataset: The LeRobotDataset used for offline pre-training.
+        offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
+        online_dataset: The OnlineBuffer used in online training.
+        online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
+            online dataset is provided, this value must also be provided.
+        online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
+            dataset.
+    Returns:
+        Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
+
+    Notes to maintainers:
+        - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
+        - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
+          `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
+          is the ability to turn shuffling off.
+        - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
+          included here to avoid adding complexity.
+    """
+    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
+        raise ValueError("At least one of `offline_dataset` or `online_dataset` should contain data.")
+    if (online_dataset is None) ^ (online_sampling_ratio is None):
+        raise ValueError(
+            "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
+ ) + offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio + + weights = [] + + if len(offline_dataset) > 0: + offline_data_mask_indices = [] + for start_index, end_index in zip( + offline_dataset.episode_data_index["from"], + offline_dataset.episode_data_index["to"], + strict=True, + ): + offline_data_mask_indices.extend( + range(start_index.item(), end_index.item() - offline_drop_n_last_frames) + ) + offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool) + offline_data_mask[torch.tensor(offline_data_mask_indices)] = True + weights.append( + torch.full( + size=(len(offline_dataset),), + fill_value=offline_sampling_ratio / offline_data_mask.sum(), + ) + * offline_data_mask + ) + + if online_dataset is not None and len(online_dataset) > 0: + online_data_mask_indices = [] + episode_indices = online_dataset.get_data_by_key("episode_index") + for episode_idx in torch.unique(episode_indices): + where_episode = torch.where(episode_indices == episode_idx) + start_index = where_episode[0][0] + end_index = where_episode[0][-1] + 1 + online_data_mask_indices.extend( + range(start_index.item(), end_index.item() - online_drop_n_last_frames) + ) + online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool) + online_data_mask[torch.tensor(online_data_mask_indices)] = True + weights.append( + torch.full( + size=(len(online_dataset),), + fill_value=online_sampling_ratio / online_data_mask.sum(), + ) + * online_data_mask + ) + + weights = torch.cat(weights) + + if weights.sum() == 0: + weights += 1 / len(weights) + else: + weights /= weights.sum() + + return weights diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 49485c39..4a5415a1 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -25,12 +25,16 @@ class TDMPCConfig: camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`. + Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`. Args: n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google action repeats in Q-learning or ask your favorite chatbot) horizon: Horizon for model predictive control. + n_action_steps: Number of action steps to take from the plan given by model predictive control. This + is an alternative to using action repeats. If this is set to more than 1, then we require + `n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this + approach of using multiple steps from the plan is not in the original implementation. input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents the input data name, and the value is a list indicating the dimensions of the corresponding data. For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], @@ -100,6 +104,7 @@ class TDMPCConfig: # Input / output structure. n_action_repeats: int = 2 horizon: int = 5 + n_action_steps: int = 1 input_shapes: dict[str, list[int]] = field( default_factory=lambda: { @@ -158,17 +163,18 @@ class TDMPCConfig: """Input validation (not exhaustive).""" # There should only be one image key. 
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} - if len(image_keys) != 1: + if len(image_keys) > 1: raise ValueError( - f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." - ) - image_key = next(iter(image_keys)) - if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]: - # TODO(alexander-soare): This limitation is solely because of code in the random shift - # augmentation. It should be able to be removed. - raise ValueError( - f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}." + f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}." ) + if len(image_keys) > 0: + image_key = next(iter(image_keys)) + if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]: + # TODO(alexander-soare): This limitation is solely because of code in the random shift + # augmentation. It should be able to be removed. + raise ValueError( + f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}." + ) if self.n_gaussian_samples <= 0: raise ValueError( f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`" @@ -179,3 +185,12 @@ class TDMPCConfig: f"advised that you stick with the default. See {self.__class__.__name__} docstring for more " "information." ) + if self.n_action_steps > 1: + if self.n_action_repeats != 1: + raise ValueError( + "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1." + ) + if not self.use_mpc: + raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.") + if self.n_action_steps > self.horizon: + raise ValueError("`n_action_steps` must be less than or equal to `horizon`.") diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 020e48a2..7dbffcef 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -19,14 +19,10 @@ The comments in this code may sometimes refer to these references: TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955) FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029) - -TODO(alexander-soare): Make rollout work for batch sizes larger than 1. -TODO(alexander-soare): Use batch-first throughout. """ # ruff: noqa: N806 -import logging from collections import deque from copy import deepcopy from functools import partial @@ -56,9 +52,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): process communication to use the xarm environment from FOWM. This is because our xarm environment uses newer dependencies and does not match the environment in FOWM. See https://github.com/huggingface/lerobot/pull/103 for implementation details. - - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO. + - We have NOT checked that training on LeRobot reproduces the results from FOWM. + - Nevertheless, we have verified that we can train TD-MPC for PushT. See + `lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`. - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not - match our xarm environment. + match our xarm environment. """ name = "tdmpc" @@ -74,22 +72,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): that they will be passed with a call to `load_state_dict` before the policy is used. 
""" super().__init__() - logging.warning( - """ - Please note several warnings for this policy. - - - Evaluation of pretrained weights created with the original FOWM code - (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a - model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across - to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter- - process communication to use the xarm environment from FOWM. This is because our xarm - environment uses newer dependencies and does not match the environment in FOWM. See - https://github.com/huggingface/lerobot/pull/103 for implementation details. - - We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO. - - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not - match our xarm environment. - """ - ) if config is None: config = TDMPCConfig() @@ -114,8 +96,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] # Note: This check is covered in the post-init of the config but have a sanity check just in case. - assert len(image_keys) == 1 - self.input_image_key = image_keys[0] + self._use_image = False + self._use_env_state = False + if len(image_keys) > 0: + assert len(image_keys) == 1 + self._use_image = True + self.input_image_key = image_keys[0] + if "observation.environment_state" in config.input_shapes: + self._use_env_state = True self.reset() @@ -125,10 +113,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): called on `env.reset()` """ self._queues = { - "observation.image": deque(maxlen=1), "observation.state": deque(maxlen=1), - "action": deque(maxlen=self.config.n_action_repeats), + "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), } + if self._use_image: + self._queues["observation.image"] = deque(maxlen=1) + if self._use_env_state: + self._queues["observation.environment_state"] = deque(maxlen=1) # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start # CEM for the next step. self._prev_mean: torch.Tensor | None = None @@ -137,8 +128,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[self.input_image_key] + if self._use_image: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) @@ -152,49 +144,57 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): batch[key] = batch[key][:, 0] # NOTE: Order of observations matters here. - z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]}) - if self.config.use_mpc: - batch_size = batch["observation.image"].shape[0] - # Batch processing is not handled in MPC mode, so process the batch in a loop. - action = [] # will be a batch of actions for one step - for i in range(batch_size): - # Note: self.plan does not handle batches, hence the squeeze. 
- action.append(self.plan(z[i])) - action = torch.stack(action) + encode_keys = [] + if self._use_image: + encode_keys.append("observation.image") + if self._use_env_state: + encode_keys.append("observation.environment_state") + encode_keys.append("observation.state") + z = self.model.encode({k: batch[k] for k in encode_keys}) + if self.config.use_mpc: # noqa: SIM108 + actions = self.plan(z) # (horizon, batch, action_dim) else: - # Plan with the policy (π) alone. - action = self.model.pi(z) + # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a + # sequence dimension like in the MPC branch. + actions = self.model.pi(z).unsqueeze(0) - self.unnormalize_outputs({"action": action})["action"] + actions = torch.clamp(actions, -1, +1) - for _ in range(self.config.n_action_repeats): - self._queues["action"].append(action) + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.n_action_repeats > 1: + for _ in range(self.config.n_action_repeats): + self._queues["action"].append(actions[0]) + else: + # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action. + self._queues["action"].extend(actions[: self.config.n_action_steps]) action = self._queues["action"].popleft() - return torch.clamp(action, -1, 1) + return action @torch.no_grad() def plan(self, z: Tensor) -> Tensor: - """Plan next action using TD-MPC inference. + """Plan sequence of actions using TD-MPC inference. Args: - z: (latent_dim,) tensor for the initial state. + z: (batch, latent_dim,) tensor for the initial state. Returns: - (action_dim,) tensor for the next action. - - TODO(alexander-soare) Extend this to be able to work with batches. + (horizon, batch, action_dim,) tensor for the planned trajectory of actions. """ device = get_device_from_parameters(self) + batch_size = z.shape[0] + # Sample Nπ trajectories from the policy. pi_actions = torch.empty( self.config.horizon, self.config.n_pi_samples, + batch_size, self.config.output_shapes["action"][0], device=device, ) if self.config.n_pi_samples > 0: - _z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples) + _z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples) for t in range(self.config.horizon): # Note: Adding a small amount of noise here doesn't hurt during inference and may even be # helpful for CEM. @@ -203,12 +203,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled # trajectories. - z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) + z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization # algorithm. # The initial mean and standard deviation for the cross-entropy method (CEM). - mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device) + mean = torch.zeros( + self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device + ) # Maybe warm start CEM with the mean from the previous step. 
if self._prev_mean is not None: mean[:-1] = self._prev_mean[1:] @@ -219,6 +221,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): std_normal_noise = torch.randn( self.config.horizon, self.config.n_gaussian_samples, + batch_size, self.config.output_shapes["action"][0], device=std.device, ) @@ -227,21 +230,24 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # Compute elite actions. actions = torch.cat([gaussian_actions, pi_actions], dim=1) value = self.estimate_value(z, actions).nan_to_num_(0) - elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices - elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch) + elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch) + # (horizon, n_elites, batch, action_dim) + elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1) - # Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites. - max_value = elite_value.max(0)[0] + # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites. + max_value = elite_value.max(0, keepdim=True)[0] # (1, batch) # The weighting is a softmax over trajectory values. Note that this is not the same as the usage # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) - score /= score.sum() - _mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1) + score /= score.sum(axis=0, keepdim=True) + # (horizon, batch, action_dim) + _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) _std = torch.sqrt( torch.sum( - einops.rearrange(score, "n -> n 1") - * (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2, + einops.rearrange(score, "n b -> n b 1") + * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2, dim=1, ) ) @@ -256,11 +262,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax # scores from the last iteration. - actions = elite_actions[:, torch.multinomial(score, 1).item()] + actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)] - # Select only the first action - action = actions[0] - return action + return actions @torch.no_grad() def estimate_value(self, z: Tensor, actions: Tensor): @@ -312,13 +316,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) return G - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Run the batch through the model and compute the loss.""" + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: + """Run the batch through the model and compute the loss. + + Returns a dictionary with loss as a tensor, and other information as native floats. 
+ """ device = get_device_from_parameters(self) batch = self.normalize_inputs(batch) - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[self.input_image_key] + if self._use_image: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) info = {} @@ -328,12 +336,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): if batch[key].ndim > 1: batch[key] = batch[key].transpose(1, 0) - action = batch["action"] # (t, b) - reward = batch["next.reward"] # (t,) + action = batch["action"] # (t, b, action_dim) + reward = batch["next.reward"] # (t, b) observations = {k: v for k, v in batch.items() if k.startswith("observation.")} # Apply random image augmentations. - if self.config.max_random_shift_ratio > 0: + if self._use_image and self.config.max_random_shift_ratio > 0: observations["observation.image"] = flatten_forward_unflatten( partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), observations["observation.image"], @@ -345,7 +353,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): for k in observations: current_observation[k] = observations[k][0] next_observations[k] = observations[k][1:] - horizon = next_observations["observation.image"].shape[0] + horizon, batch_size = next_observations[ + "observation.image" if self._use_image else "observation.environment_state" + ].shape[:2] # Run latent rollout using the latent dynamics model and policy model. # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action @@ -415,7 +425,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. q_value_loss = ( ( - F.mse_loss( + temporal_loss_coeffs + * F.mse_loss( q_preds_ensemble, einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), reduction="none", @@ -464,10 +475,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) # Calculate the MSE between the actions and the action predictions. # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation - # gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying - # the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset - # as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration - # parameter for it (see below where we compute the total loss). + # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to + # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action + # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop + # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total + # loss). mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b) # NOTE: The original implementation does not take the sum over the temporal dimension like with the # other losses. 
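As a quick standalone check of the comment above (an illustrative snippet, separate from the diff itself): for a unit-variance Gaussian, the negative log-likelihood of an action under the predicted mean is exactly half the summed MSE plus a constant, so the two objectives differ only by a scale factor and an offset.

import math

import torch

torch.manual_seed(0)
action_dim = 4
action = torch.randn(action_dim)
action_pred = torch.randn(action_dim)

# Negative log-likelihood of `action` under N(action_pred, I), summed over the action dimension.
nll = 0.5 * ((action - action_pred) ** 2).sum() + 0.5 * action_dim * math.log(2 * math.pi)

# The MSE summed over the action dimension, as used in the policy loss.
mse = ((action - action_pred) ** 2).sum()

# nll == 0.5 * mse + constant, so dropping the constant and folding the 0.5 into a configurable
# coefficient leaves the optimization unchanged.
assert torch.isclose(nll, 0.5 * mse + 0.5 * action_dim * math.log(2 * math.pi))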
@@ -728,6 +740,16 @@ class TDMPCObservationEncoder(nn.Module):
                 nn.LayerNorm(config.latent_dim),
                 nn.Sigmoid(),
             )
+        if "observation.environment_state" in config.input_shapes:
+            self.env_state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
+                ),
+                nn.ELU(),
+                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
+                nn.LayerNorm(config.latent_dim),
+                nn.Sigmoid(),
+            )

     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -736,8 +758,11 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         over all features.
         """
         feat = []
+        # NOTE: Order of observations matters here.
         if "observation.image" in self.config.input_shapes:
             feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
+        if "observation.environment_state" in self.config.input_shapes:
+            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
         return torch.stack(feat, dim=0).mean(0)
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 4bb1508d..a3ff1d41 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -32,19 +32,54 @@ video_backend: pyav

 training:
   offline_steps: ???
-  # NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
-  online_steps: ???
-  online_steps_between_rollouts: ???
-  online_sampling_ratio: 0.5
-  # `online_env_seed` is used for environments for online training data rollouts.
-  online_env_seed: ???
+
+  # Number of workers for the offline training dataloader.
+  num_workers: 4
+
+  batch_size: ???
+
   eval_freq: ???
   log_freq: 200
   save_checkpoint: true
   # Checkpoint is saved every `save_freq` training iterations and after the last training step.
   save_freq: ???
-  num_workers: 4
-  batch_size: ???
+
+  # Online training. Note that the online training loop adopts most of the options above, apart from the
+  # dataloader options, unless otherwise specified.
+  # The online training loop looks something like:
+  #
+  # for i in range(online_steps):
+  #     do_online_rollout_and_update_online_buffer()
+  #     for j in range(online_steps_between_rollouts):
+  #         batch = next(dataloader_with_offline_and_online_data)
+  #         loss = policy(batch)
+  #         loss.backward()
+  #         optimizer.step()
+  #
+  online_steps: ???
+  # How many episodes to collect at once when we reach the online rollout part of the training loop.
+  online_rollout_n_episodes: 1
+  # The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
+  # the policy. Ideally you should set this to be an even divisor of online_rollout_n_episodes.
+  online_rollout_batch_size: 1
+  # How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
+  online_steps_between_rollouts: null
+  # The proportion of online samples (vs offline samples) to include in the online training batches.
+  online_sampling_ratio: 0.5
+  # First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
+  online_env_seed: null
+  # Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
+  # FIFO.
+  online_buffer_capacity: null
+  # The minimum number of frames to have in the online buffer before commencing online training.
+ # If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the + # seed size condition is satisfied. + online_buffer_seed_size: 0 + # Whether to run the online rollouts asynchronously. This means we can run the online training steps in + # parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training + # + eval + environment rendering simultaneously. + do_online_rollout_async: false + image_transforms: # These transforms are all using standard torchvision.transforms.v2 # You can find out how these transformations affect images here: diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 9dbb96f5..4320379a 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -9,7 +9,7 @@ env: state_dim: 4 action_dim: 4 fps: ${fps} - episode_length: 25 + episode_length: 200 gym: obs_type: pixels_agent_pos render_mode: rgb_array diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 379e9320..40eab35f 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -4,19 +4,30 @@ seed: 1 dataset_repo_id: lerobot/xarm_lift_medium training: - offline_steps: 25000 - # TODO(alexander-soare): uncomment when online training gets reinstated - online_steps: 0 # 25000 not implemented yet - eval_freq: 5000 - online_steps_between_rollouts: 1 - online_sampling_ratio: 0.5 - online_env_seed: 10000 - log_freq: 100 + offline_steps: 50000 + + num_workers: 4 batch_size: 256 grad_clip_norm: 10.0 lr: 3e-4 + eval_freq: 5000 + log_freq: 100 + + online_steps: 50000 + online_rollout_n_episodes: 1 + online_rollout_batch_size: 1 + # Note: in FOWM `online_steps_between_rollouts` is actually dynamically set to match exactly the length of + # the last sampled episode. + online_steps_between_rollouts: 50 + online_sampling_ratio: 0.5 + online_env_seed: 10000 + # FOWM Push uses 10000 for `online_buffer_capacity`. Given that their maximum episode length for this task + # is 25, 10000 is approx 400 of their episodes worth. Since our episodes are about 8 times longer, we'll use + # 80000. + online_buffer_capacity: 80000 + delta_timestamps: observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]" observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" @@ -31,6 +42,7 @@ policy: # Input / output structure. n_action_repeats: 2 horizon: 5 + n_action_steps: 1 input_shapes: # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? 
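For context on how the new pieces fit together: the OnlineBuffer, compute_sampler_weights and the online training options above are wired up in train.py roughly as in the sketch below. The write_dir, the data_spec shapes, the fps and the dataloader batch size are illustrative values, not taken from the diff; only the module paths, function names and argument names come from the code added here.

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights

# Offline dataset used for pre-training (repo id taken from the PushT keypoints config).
offline_dataset = LeRobotDataset("lerobot/pusht_keypoints")

# FIFO buffer backed by numpy memmaps. Shapes here are made up for illustration: a 2D state and a 2D
# action, with scalar reward/done/success flags, capped at 1000 frames.
online_dataset = OnlineBuffer(
    "outputs/example/online_buffer",  # memmap files are created here (or reopened on resume)
    data_spec={
        "observation.state": {"shape": (2,), "dtype": np.dtype("float32")},
        "action": {"shape": (2,), "dtype": np.dtype("float32")},
        "next.reward": {"shape": (), "dtype": np.dtype("float32")},
        "next.done": {"shape": (), "dtype": np.dtype("?")},
        "next.success": {"shape": (), "dtype": np.dtype("?")},
    },
    buffer_capacity=1000,
    fps=10,
    delta_timestamps=None,
)

# Offline and online frames are mixed by a single weighted sampler over the concatenated dataset rather
# than by two separate dataloaders. After each rollout, `online_dataset.add_data(...)` is called and
# these weights are recomputed.
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
sampler = torch.utils.data.WeightedRandomSampler(
    compute_sampler_weights(
        offline_dataset,
        online_dataset=online_dataset,
        online_sampling_ratio=0.5,  # on average, half of each batch comes from the online buffer
    ),
    num_samples=len(concat_dataset),
    replacement=True,
)
dataloader = torch.utils.data.DataLoader(concat_dataset, batch_size=256, sampler=sampler)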
diff --git a/lerobot/configs/policy/tdmpc_pusht_keypoints.yaml b/lerobot/configs/policy/tdmpc_pusht_keypoints.yaml new file mode 100644 index 00000000..1cfc5b52 --- /dev/null +++ b/lerobot/configs/policy/tdmpc_pusht_keypoints.yaml @@ -0,0 +1,105 @@ +# @package _global_ + +# Train with: +# +# python lerobot/scripts/train.py \ +# env=pusht \ +# env.gym.obs_type=environment_state_agent_pos \ +# policy=tdmpc_pusht_keypoints \ +# eval.batch_size=50 \ +# eval.n_episodes=50 \ +# eval.use_async_envs=true \ +# device=cuda \ +# use_amp=true + +seed: 1 +dataset_repo_id: lerobot/pusht_keypoints + +training: + offline_steps: 0 + + # Offline training dataloader + num_workers: 4 + + batch_size: 256 + grad_clip_norm: 10.0 + lr: 3e-4 + + eval_freq: 10000 + log_freq: 500 + save_freq: 50000 + + online_steps: 1000000 + online_rollout_n_episodes: 10 + online_rollout_batch_size: 10 + online_steps_between_rollouts: 1000 + online_sampling_ratio: 1.0 + online_env_seed: 10000 + online_buffer_capacity: 40000 + online_buffer_seed_size: 0 + do_online_rollout_async: false + + delta_timestamps: + observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + action: "[i / ${fps} for i in range(${policy.horizon})]" + next.reward: "[i / ${fps} for i in range(${policy.horizon})]" + +policy: + name: tdmpc + + pretrained_model_path: + + # Input / output structure. + n_action_repeats: 1 + horizon: 5 + n_action_steps: 5 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.environment_state: [16] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.environment_state: min_max + observation.state: min_max + output_normalization_modes: + action: min_max + + # Architecture / modeling. + # Neural networks. + image_encoder_hidden_dim: 32 + state_encoder_hidden_dim: 256 + latent_dim: 50 + q_ensemble_size: 5 + mlp_dim: 512 + # Reinforcement learning. + discount: 0.98 + + # Inference. + use_mpc: true + cem_iterations: 6 + max_std: 2.0 + min_std: 0.05 + n_gaussian_samples: 512 + n_pi_samples: 51 + uncertainty_regularizer_coeff: 1.0 + n_elites: 50 + elite_weighting_temperature: 0.5 + gaussian_mean_momentum: 0.1 + + # Training and loss computation. + max_random_shift_ratio: 0.0476 + # Loss coefficients. + reward_coeff: 0.5 + expectile_weight: 0.9 + value_coeff: 0.1 + consistency_coeff: 20.0 + advantage_scaling: 3.0 + pi_coeff: 0.5 + temporal_decay_coeff: 0.5 + # Target model. 
+ target_model_momentum: 0.995 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 486b4d2b..a07f3530 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -56,16 +56,13 @@ import einops import gymnasium as gym import numpy as np import torch -from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets from huggingface_hub import snapshot_download from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError -from PIL import Image as PILImage from torch import Tensor, nn from tqdm import trange from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.utils import hf_transform_to_torch from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import preprocess_observation from lerobot.common.logger import log_output_dir @@ -318,41 +315,17 @@ def eval_policy( rollout_data, done_indices, start_episode_index=batch_ix * env.num_envs, - start_data_index=( - 0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item()) - ), + start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), fps=env.unwrapped.metadata["render_fps"], ) if episode_data is None: episode_data = this_episode_data else: - # Some sanity checks to make sure we are not correctly compiling the data. - assert ( - episode_data["hf_dataset"]["episode_index"][-1] + 1 - == this_episode_data["hf_dataset"]["episode_index"][0] - ) - assert ( - episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0] - ) - assert torch.equal( - episode_data["episode_data_index"]["to"][-1], - this_episode_data["episode_data_index"]["from"][0], - ) + # Some sanity checks to make sure we are correctly compiling the data. + assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0] + assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] # Concatenate the episode data. - episode_data = { - "hf_dataset": concatenate_datasets( - [episode_data["hf_dataset"], this_episode_data["hf_dataset"]] - ), - "episode_data_index": { - k: torch.cat( - [ - episode_data["episode_data_index"][k], - this_episode_data["episode_data_index"][k], - ] - ) - for k in ["from", "to"] - }, - } + episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data} # Maybe render video for visualization. if max_episodes_rendered > 0 and len(ep_frames) > 0: @@ -434,89 +407,39 @@ def _compile_episode_data( Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`). """ ep_dicts = [] - episode_data_index = {"from": [], "to": []} total_frames = 0 - data_index_from = start_data_index for ep_ix in range(rollout_data["action"].shape[0]): - num_frames = done_indices[ep_ix].item() + 1 # + 1 to include the first done frame + # + 2 to include the first done frame and the last observation frame. + num_frames = done_indices[ep_ix].item() + 2 total_frames += num_frames - # TODO(rcadene): We need to add a missing last frame which is the observation - # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state" + # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. 
ep_dict = { - "action": rollout_data["action"][ep_ix, :num_frames], - "episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames), - "frame_index": torch.arange(0, num_frames, 1), - "timestamp": torch.arange(0, num_frames, 1) / fps, - "next.done": rollout_data["done"][ep_ix, :num_frames], - "next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32), + "action": rollout_data["action"][ep_ix, : num_frames - 1], + "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), + "frame_index": torch.arange(0, num_frames - 1, 1), + "timestamp": torch.arange(0, num_frames - 1, 1) / fps, + "next.done": rollout_data["done"][ep_ix, : num_frames - 1], + "next.success": rollout_data["success"][ep_ix, : num_frames - 1], + "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), } + + # For the last observation frame, all other keys will just be copy padded. + for k in ep_dict: + ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]]) + for key in rollout_data["observation"]: - ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames] + ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames] + ep_dicts.append(ep_dict) - episode_data_index["from"].append(data_index_from) - episode_data_index["to"].append(data_index_from + num_frames) - - data_index_from += num_frames - data_dict = {} for key in ep_dicts[0]: - if "image" not in key: - data_dict[key] = torch.cat([x[key] for x in ep_dicts]) - else: - if key not in data_dict: - data_dict[key] = [] - for ep_dict in ep_dicts: - for img in ep_dict[key]: - # sanity check that images are channel first - c, h, w = img.shape - assert c < h and c < w, f"expect channel first images, but instead {img.shape}" - - # sanity check that images are float32 in range [0,1] - assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}" - assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}" - assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}" - - # from float32 in range [0,1] to uint8 in range [0,255] - img *= 255 - img = img.type(torch.uint8) - - # convert to channel last and numpy as expected by PIL - img = PILImage.fromarray(img.permute(1, 2, 0).numpy()) - - data_dict[key].append(img) + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) - episode_data_index["from"] = torch.tensor(episode_data_index["from"]) - episode_data_index["to"] = torch.tensor(episode_data_index["to"]) - # TODO(rcadene): clean this - features = {} - for key in rollout_data["observation"]: - if "image" in key: - features[key] = Image() - else: - features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)) - features.update( - { - "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_index": Value(dtype="int64", id=None), - "frame_index": Value(dtype="int64", id=None), - "timestamp": Value(dtype="float32", id=None), - "next.reward": Value(dtype="float32", id=None), - "next.done": Value(dtype="bool", id=None), - #'next.success': Value(dtype='bool', id=None), - "index": Value(dtype="int64", id=None), - } - ) - features = Features(features) - hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset.set_transform(hf_transform_to_torch) - return { - "hf_dataset": hf_dataset, - "episode_data_index": episode_data_index, - } + return 
data_dict def main( diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f707fe12..d8fdfc1f 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -15,20 +15,25 @@ # limitations under the License. import logging import time +from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext +from copy import deepcopy from pathlib import Path from pprint import pformat +from threading import Lock import hydra +import numpy as np import torch from deepdiff import DeepDiff -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf from termcolor import colored from torch import nn from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset +from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env @@ -107,6 +112,7 @@ def update_policy( grad_scaler: GradScaler, lr_scheduler=None, use_amp: bool = False, + lock=None, ): """Returns a dictionary of items for logging.""" start_time = time.perf_counter() @@ -129,7 +135,8 @@ def update_policy( # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, # although it still skips optimizer.step() if the gradients contain infs or NaNs. - grad_scaler.step(optimizer) + with lock if lock is not None else nullcontext(): + grad_scaler.step(optimizer) # Updates the scale for next iteration. grad_scaler.update() @@ -149,11 +156,12 @@ def update_policy( "update_s": time.perf_counter() - start_time, **{k: v for k, v in output_dict.items() if k != "loss"}, } + info.update({k: v for k, v in output_dict.items() if k not in info}) return info -def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline): +def log_train_info(logger: Logger, info, step, cfg, dataset, is_online): loss = info["loss"] grad_norm = info["grad_norm"] lr = info["lr"] @@ -187,12 +195,12 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline): info["num_samples"] = num_samples info["num_episodes"] = num_episodes info["num_epochs"] = num_epochs - info["is_offline"] = is_offline + info["is_online"] = is_online logger.log_dict(info, step, mode="train") -def log_eval_info(logger, info, step, cfg, dataset, is_offline): +def log_eval_info(logger, info, step, cfg, dataset, is_online): eval_s = info["eval_s"] avg_sum_reward = info["avg_sum_reward"] pc_success = info["pc_success"] @@ -221,7 +229,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): info["num_samples"] = num_samples info["num_episodes"] = num_episodes info["num_epochs"] = num_epochs - info["is_offline"] = is_offline + info["is_online"] = is_online logger.log_dict(info, step, mode="eval") @@ -234,6 +242,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No init_logging() + if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig): + raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.") + # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need # to check for any differences between the provided config and the checkpoint's config. 
if cfg.resume: @@ -279,9 +290,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # log metrics to terminal and wandb logger = Logger(cfg, out_dir, wandb_job_name=job_name) - if cfg.training.online_steps > 0: - raise NotImplementedError("Online training is not implemented yet.") - set_global_seed(cfg.seed) # Check device is available @@ -336,7 +344,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # Note: this helper will be used in offline and online training loops. - def evaluate_and_checkpoint_if_needed(step): + def evaluate_and_checkpoint_if_needed(step, is_online): _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) step_identifier = f"{step:0{_num_digits}d}" @@ -352,7 +360,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No max_episodes_rendered=4, start_seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True) + log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online) if cfg.wandb.enable: logger.log_video(eval_info["video_paths"][0], step, mode="eval") logging.info("Resume training") @@ -396,8 +404,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dl_iter = cycle(dataloader) policy.train() + offline_step = 0 for _ in range(step, cfg.training.offline_steps): - if step == 0: + if offline_step == 0: logging.info("Start offline training on a fixed dataset") start_time = time.perf_counter() @@ -420,13 +429,207 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No train_info["dataloading_s"] = dataloading_s if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True) + log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False) # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, # so we pass in step + 1. - evaluate_and_checkpoint_if_needed(step + 1) + evaluate_and_checkpoint_if_needed(step + 1, is_online=False) step += 1 + offline_step += 1 # noqa: SIM113 + + if cfg.training.online_steps == 0: + if eval_env: + eval_env.close() + logging.info("End of training") + return + + # Online training. + + # Create an env dedicated to online episodes collection from policy rollout. + online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size) + resolve_delta_timestamps(cfg) + online_buffer_path = logger.log_dir / "online_buffer" + if cfg.resume and not online_buffer_path.exists(): + # If we are resuming a run, we default to the data shapes and buffer capacity from the saved online + # buffer. + logging.warning( + "When online training is resumed, we load the latest online buffer from the prior run, " + "and this might not coincide with the state of the buffer as it was at the moment the checkpoint " + "was made. This is because the online buffer is updated on disk during training, independently " + "of our explicit checkpointing mechanisms." 
+ ) + online_dataset = OnlineBuffer( + online_buffer_path, + data_spec={ + **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()}, + **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()}, + "next.reward": {"shape": (), "dtype": np.dtype("float32")}, + "next.done": {"shape": (), "dtype": np.dtype("?")}, + "next.success": {"shape": (), "dtype": np.dtype("?")}, + }, + buffer_capacity=cfg.training.online_buffer_capacity, + fps=online_env.unwrapped.metadata["render_fps"], + delta_timestamps=cfg.training.delta_timestamps, + ) + + # If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this + # makes it possible to do online rollouts in parallel with training updates). + online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy + + # Create dataloader for online training. + concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) + sampler_weights = compute_sampler_weights( + offline_dataset, + offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0), + online_dataset=online_dataset, + # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have + # this final observation in the offline datasets, but we might add them in future. + online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1, + online_sampling_ratio=cfg.training.online_sampling_ratio, + ) + sampler = torch.utils.data.WeightedRandomSampler( + sampler_weights, + num_samples=len(concat_dataset), + replacement=True, + ) + dataloader = torch.utils.data.DataLoader( + concat_dataset, + batch_size=cfg.training.batch_size, + num_workers=cfg.training.num_workers, + sampler=sampler, + pin_memory=device.type != "cpu", + drop_last=True, + ) + dl_iter = cycle(dataloader) + + # Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled, + # these are still used but effectively do nothing. + lock = Lock() + # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch + # parallelization of rollouts is handled within the job. + executor = ThreadPoolExecutor(max_workers=1) + + online_step = 0 + online_rollout_s = 0 # time take to do online rollout + update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data + # Time taken waiting for the online buffer to finish being updated. This is relevant when using the async + # online rollout option. 
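The offline/online mixing set up above relies on a WeightedRandomSampler over a ConcatDataset, with the weights encoding the desired online sampling ratio. A rough sketch of that mechanism with toy tensors; the helper below is an assumed stand-in for compute_sampler_weights, not the actual implementation:

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

offline = TensorDataset(torch.zeros(6, 1))  # stand-in for the offline dataset
online = TensorDataset(torch.ones(4, 1))    # stand-in for the online buffer

def toy_weights(n_offline: int, n_online: int, online_ratio: float) -> torch.Tensor:
    # Hypothetical equivalent of compute_sampler_weights: each split's total mass
    # matches its sampling ratio, spread uniformly over its frames.
    w_off = torch.full((n_offline,), (1 - online_ratio) / max(n_offline, 1))
    w_on = torch.full((n_online,), online_ratio / max(n_online, 1))
    return torch.cat([w_off, w_on])

concat = ConcatDataset([offline, online])
weights = toy_weights(len(offline), len(online), online_ratio=0.8)
sampler = WeightedRandomSampler(weights, num_samples=len(concat), replacement=True)
loader = DataLoader(concat, batch_size=2, sampler=sampler)
batch = next(iter(loader))  # on average ~80% of samples come from the online split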
+ await_update_online_buffer_s = 0 + rollout_start_seed = cfg.training.online_env_seed + + while True: + if online_step == cfg.training.online_steps: + break + + if online_step == 0: + logging.info("Start online training by interacting with environment") + + def sample_trajectory_and_update_buffer(): + nonlocal rollout_start_seed + with lock: + online_rollout_policy.load_state_dict(policy.state_dict()) + online_rollout_policy.eval() + start_rollout_time = time.perf_counter() + with torch.no_grad(): + eval_info = eval_policy( + online_env, + online_rollout_policy, + n_episodes=cfg.training.online_rollout_n_episodes, + max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes), + videos_dir=logger.log_dir / "online_rollout_videos", + return_episode_data=True, + start_seed=( + rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000 + ), + ) + online_rollout_s = time.perf_counter() - start_rollout_time + + with lock: + start_update_buffer_time = time.perf_counter() + online_dataset.add_data(eval_info["episodes"]) + + # Update the concatenated dataset length used during sampling. + concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) + + # Update the sampling weights. + sampler.weights = compute_sampler_weights( + offline_dataset, + offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0), + online_dataset=online_dataset, + # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have + # this final observation in the offline datasets, but we might add them in future. + online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1, + online_sampling_ratio=cfg.training.online_sampling_ratio, + ) + sampler.num_samples = len(concat_dataset) + + update_online_buffer_s = time.perf_counter() - start_update_buffer_time + + return online_rollout_s, update_online_buffer_s + + future = executor.submit(sample_trajectory_and_update_buffer) + # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait + # here until the rollout and buffer update is done, before proceeding to the policy update steps. 
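One subtle detail in the buffer update above: ConcatDataset caches cumulative_sizes at construction time, so after the online buffer grows the cache must be recomputed (hence the `concat_dataset.cumsum(...)` call), otherwise the new frames are invisible to the sampler. A small illustration with a toy mutable dataset, purely for demonstration:

import torch
from torch.utils.data import ConcatDataset, Dataset

class GrowableDataset(Dataset):
    """Toy dataset whose length can change after construction."""
    def __init__(self, items):
        self.items = list(items)
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

a, b = GrowableDataset([0, 1]), GrowableDataset([10])
concat = ConcatDataset([a, b])
assert len(concat) == 3

b.items.extend([11, 12])  # the underlying dataset grows...
assert len(concat) == 3   # ...but the cached cumulative sizes are stale

concat.cumulative_sizes = concat.cumsum(concat.datasets)  # refresh the cache
assert len(concat) == 5   # the new frames are now reachable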
+ if ( + not cfg.training.do_online_rollout_async + or len(online_dataset) <= cfg.training.online_buffer_seed_size + ): + online_rollout_s, update_online_buffer_s = future.result() + + if len(online_dataset) <= cfg.training.online_buffer_seed_size: + logging.info( + f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}" + ) + continue + + policy.train() + for _ in range(cfg.training.online_steps_between_rollouts): + with lock: + start_time = time.perf_counter() + batch = next(dl_iter) + dataloading_s = time.perf_counter() - start_time + + for key in batch: + batch[key] = batch[key].to(cfg.device, non_blocking=True) + + train_info = update_policy( + policy, + batch, + optimizer, + cfg.training.grad_clip_norm, + grad_scaler=grad_scaler, + lr_scheduler=lr_scheduler, + use_amp=cfg.use_amp, + lock=lock, + ) + + train_info["dataloading_s"] = dataloading_s + train_info["online_rollout_s"] = online_rollout_s + train_info["update_online_buffer_s"] = update_online_buffer_s + train_info["await_update_online_buffer_s"] = await_update_online_buffer_s + with lock: + train_info["online_buffer_size"] = len(online_dataset) + + if step % cfg.training.log_freq == 0: + log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True) + + # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, + # so we pass in step + 1. + evaluate_and_checkpoint_if_needed(step + 1, is_online=True) + + step += 1 + online_step += 1 + + # If we're doing async rollouts, we should now wait until we've completed them before proceeding + # to do the next batch of rollouts. + if future.running(): + start = time.perf_counter() + online_rollout_s, update_online_buffer_s = future.result() + await_update_online_buffer_s = time.perf_counter() - start + + if online_step >= cfg.training.online_steps: + break if eval_env: eval_env.close() diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors deleted file mode 100644 index b4fe1140..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:376c501d2780c7204850b58210a5a9476347cf9b5afb8f45b185d23ad6b5be4d -size 928 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors deleted file mode 100644 index a9b41bd7..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:213310c6fceffa9fd31066b87b9305484ff7289051a67f3a2490c39640fe7e28 -size 16904 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors deleted file mode 100644 index 724e8ae0..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2499552badd9201bc73bfa91ba881f43cecdfd93c6f3ca14d3aaf753828a0f4d -size 240 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors deleted file mode 100644 index 3b6cd374..00000000 --- a/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors +++ /dev/null 
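Stepping back from the training loop above: the async rollout machinery boils down to a one-worker ThreadPoolExecutor. Submit the rollout, keep training, and only block on `future.result()` when the data is actually needed (buffer seeding) or before launching the next rollout. A stripped-down sketch of that control flow, with illustrative timings and names:

import time
from concurrent.futures import ThreadPoolExecutor

def fake_rollout() -> float:
    # Stand-in for "run episodes in the env and push them into the online buffer".
    time.sleep(0.05)
    return 0.05

executor = ThreadPoolExecutor(max_workers=1)  # at most one rollout in flight at a time

for iteration in range(3):
    future = executor.submit(fake_rollout)
    # ... policy update steps would run here, concurrently with the rollout ...
    if not future.done():
        wait_start = time.perf_counter()
        future.result()  # block until the rollout (and buffer update) has finished
        print(f"iter {iteration}: waited {time.perf_counter() - wait_start:.3f}s for the rollout")

executor.shutdown()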
@@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9ff50ee7b750022e17d867ff20307259da8d99a48fff440499e8ca9b4cf42a4a -size 36312 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/actions.safetensors new file mode 100644 index 00000000..e2fb68ac --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a +size 472 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/grad_stats.safetensors new file mode 100644 index 00000000..cf756229 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0 +size 16904 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/output_dict.safetensors new file mode 100644 index 00000000..f8863cfb --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee +size 240 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/param_stats.safetensors new file mode 100644 index 00000000..8ce3c4f3 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_mpc/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6 +size 36312 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/actions.safetensors new file mode 100644 index 00000000..1b3912ed --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1 +size 472 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/grad_stats.safetensors new file mode 100644 index 00000000..cf756229 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0 +size 16904 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/output_dict.safetensors new file mode 100644 index 00000000..f8863cfb --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee +size 240 diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/param_stats.safetensors 
b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/param_stats.safetensors new file mode 100644 index 00000000..8ce3c4f3 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/xarm_tdmpcuse_policy/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6 +size 36312 diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 52c1c520..5236b7ae 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -108,7 +108,8 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": env_policies = [ - # ("xarm", "tdmpc", ["policy.use_mpc=false"], ""), + # ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"), + # ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"), # ( # "pusht", # "diffusion", diff --git a/tests/test_online_buffer.py b/tests/test_online_buffer.py new file mode 100644 index 00000000..37000e4f --- /dev/null +++ b/tests/test_online_buffer.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.d +from copy import deepcopy +from uuid import uuid4 + +import numpy as np +import pytest +import torch +from datasets import Dataset + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights +from lerobot.common.datasets.utils import hf_transform_to_torch + +# Some constants for OnlineBuffer tests. +data_key = "data" +data_shape = (2, 3) # just some arbitrary > 1D shape +buffer_capacity = 100 +fps = 10 + + +def make_new_buffer( + write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None +) -> tuple[OnlineBuffer, str]: + if write_dir is None: + write_dir = f"/tmp/online_buffer_{uuid4().hex}" + buffer = OnlineBuffer( + write_dir, + data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}}, + buffer_capacity=buffer_capacity, + fps=fps, + delta_timestamps=delta_timestamps, + ) + return buffer, write_dir + + +def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]: + new_data = { + data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape), + OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes), + OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode), + OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes), + OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes), + } + return new_data + + +def test_non_mutate(): + """Checks that the data provided to the add_data method is copied rather than passed by reference. + + This means that mutating the data in the buffer does not mutate the original data. 
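The spoof-frame helper above encodes the per-frame bookkeeping that OnlineBuffer expects: a global `index`, an `episode_index` repeated per episode, a `frame_index` restarting at 0 in each episode, and timestamps derived from `frame_index / fps`. A tiny worked instance of that layout, assuming 2 episodes of 3 frames at fps 10:

import numpy as np

n_episodes, n_frames_per_episode, fps = 2, 3, 10

index = np.arange(n_episodes * n_frames_per_episode)
episode_index = np.repeat(np.arange(n_episodes), n_frames_per_episode)
frame_index = np.tile(np.arange(n_frames_per_episode), n_episodes)
timestamp = np.tile(np.arange(n_frames_per_episode) / fps, n_episodes)

assert index.tolist() == [0, 1, 2, 3, 4, 5]
assert episode_index.tolist() == [0, 0, 0, 1, 1, 1]
assert frame_index.tolist() == [0, 1, 2, 0, 1, 2]
assert timestamp.tolist() == [0.0, 0.1, 0.2, 0.0, 0.1, 0.2]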
+ + NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust + a success case for `test_write_read`. + """ + buffer, _ = make_new_buffer() + new_data = make_spoof_data_frames(2, buffer_capacity // 4) + new_data_copy = deepcopy(new_data) + buffer.add_data(new_data) + buffer._data[data_key][:] += 1 + assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data) + + +def test_index_error_no_data(): + buffer, _ = make_new_buffer() + with pytest.raises(IndexError): + buffer[0] + + +def test_index_error_with_data(): + buffer, _ = make_new_buffer() + n_frames = buffer_capacity // 2 + new_data = make_spoof_data_frames(1, n_frames) + buffer.add_data(new_data) + with pytest.raises(IndexError): + buffer[n_frames] + with pytest.raises(IndexError): + buffer[-n_frames - 1] + + +@pytest.mark.parametrize("do_reload", [False, True]) +def test_write_read(do_reload: bool): + """Checks that data can be added to the buffer and read back. + + If do_reload we delete the buffer object and load the buffer back from disk before reading. + """ + buffer, write_dir = make_new_buffer() + n_episodes = 2 + n_frames_per_episode = buffer_capacity // 4 + new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) + buffer.add_data(new_data) + + if do_reload: + del buffer + buffer, _ = make_new_buffer(write_dir) + + assert len(buffer) == n_frames_per_episode * n_episodes + for i, item in enumerate(buffer): + assert all(isinstance(item[k], torch.Tensor) for k in item) + assert np.array_equal(item[data_key].numpy(), new_data[data_key][i]) + + +def test_read_data_key(): + """Tests that data can be added to a buffer and all data for a. specific key can be read back.""" + buffer, _ = make_new_buffer() + n_episodes = 2 + n_frames_per_episode = buffer_capacity // 4 + new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) + buffer.add_data(new_data) + + data_from_buffer = buffer.get_data_by_key(data_key) + assert isinstance(data_from_buffer, torch.Tensor) + assert np.array_equal(data_from_buffer.numpy(), new_data[data_key]) + + +def test_fifo(): + """Checks that if data is added beyond the buffer capacity, we discard the oldest data first.""" + buffer, _ = make_new_buffer() + n_frames_per_episode = buffer_capacity // 4 + n_episodes = 3 + new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) + buffer.add_data(new_data) + n_more_episodes = 2 + # Developer sanity check (in case someone changes the global `buffer_capacity`). + assert ( + n_episodes + n_more_episodes + ) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code." + more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode) + buffer.add_data(more_new_data) + assert len(buffer) == buffer_capacity, "The buffer should be full." + + expected_data = {} + for k in new_data: + # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer. + expected_data[k] = np.roll( + np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:], + shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity, + axis=0, + ) + + for i, item in enumerate(buffer): + assert all(isinstance(item[k], torch.Tensor) for k in item) + assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i]) + + +def test_delta_timestamps_within_tolerance(): + """Check that getting an item with delta_timestamps within tolerance succeeds. 
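The expected-data construction in test_fifo above can be hard to parse: concatenate everything ever added, keep only the last `buffer_capacity` frames, then roll so the array lines up with the physical write position of the circular buffer. A miniature version with a capacity of 4 and purely illustrative numbers:

import numpy as np

capacity = 4
first = np.array([0, 1, 2])   # first batch of frames added
second = np.array([3, 4, 5])  # second batch; 6 frames total > capacity, so 0 and 1 get overwritten

kept = np.concatenate([first, second])[-capacity:]                    # logical FIFO order: [2, 3, 4, 5]
physical = np.roll(kept, shift=len(first) + len(second) - capacity)   # physical slot order: [4, 5, 2, 3]

# The buffer wrote 0,1,2 into slots 0..2, then 3 into slot 3, then wrapped and overwrote slots 0,1 with 4,5.
assert physical.tolist() == [4, 5, 2, 3]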
+ + Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`. + """ + # Sanity check on global fps as we are assuming it is 10 here. + assert fps == 10, "This test assumes fps==10" + buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]}) + new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) + buffer.add_data(new_data) + buffer.tolerance_s = 0.04 + item = buffer[2] + data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"] + assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values" + assert not is_pad.any(), "Unexpected padding detected" + + +def test_delta_timestamps_outside_tolerance_inside_episode_range(): + """Check that getting an item with delta_timestamps outside of tolerance fails. + + We expect it to fail if and only if the requested timestamps are within the episode range. + + Note: Copied from + `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range` + """ + # Sanity check on global fps as we are assuming it is 10 here. + assert fps == 10, "This test assumes fps==10" + buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]}) + new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) + buffer.add_data(new_data) + buffer.tolerance_s = 0.04 + with pytest.raises(AssertionError): + buffer[2] + + +def test_delta_timestamps_outside_tolerance_outside_episode_range(): + """Check that copy-padding of timestamps outside of the episode range works. + + Note: Copied from + `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range` + """ + # Sanity check on global fps as we are assuming it is 10 here. + assert fps == 10, "This test assumes fps==10" + buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]}) + new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) + buffer.add_data(new_data) + buffer.tolerance_s = 0.04 + item = buffer[2] + data, is_pad = item["index"], item["index_is_pad"] + assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" + assert torch.equal( + is_pad, torch.tensor([True, False, False, True, True]) + ), "Padding does not match expected values" + + +# Arbitrarily set small dataset sizes, making sure to have uneven sizes. +@pytest.mark.parametrize("offline_dataset_size", [0, 6]) +@pytest.mark.parametrize("online_dataset_size", [0, 4]) +@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0]) +def test_compute_sampler_weights_trivial( + offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float +): + # Pass/skip the test if both datasets sizes are zero. + if offline_dataset_size + online_dataset_size == 0: + return + # Create spoof offline dataset. + offline_dataset = LeRobotDataset.from_preloaded( + hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))}) + ) + offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) + if offline_dataset_size == 0: + offline_dataset.episode_data_index = {} + else: + # Set up an episode_data_index with at least two episodes. + offline_dataset.episode_data_index = { + "from": torch.tensor([0, offline_dataset_size // 2]), + "to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]), + } + # Create spoof online datset. 
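For the tolerance tests above, a quick way to see why a delta of 0.139 passes while 0.141 fails with `tolerance_s = 0.04` at fps 10: frames exist at multiples of 0.1 s, and a requested timestamp is accepted only if the nearest existing frame is within the tolerance. A small sanity check of that arithmetic, written as an assumed simplification of the buffer's internal check rather than its actual code:

import numpy as np

fps = 10
tolerance_s = 0.04
frame_timestamps = np.arange(5) / fps  # frames at 0.0, 0.1, 0.2, 0.3, 0.4
query_ts = 0.2                         # timestamp of the item at index 2

for delta in (0.139, 0.141):
    requested = query_ts + delta
    distance = np.min(np.abs(frame_timestamps - requested))
    within = distance < tolerance_s
    print(f"delta={delta}: nearest frame is {distance:.3f}s away -> {'ok' if within else 'too far'}")
# delta=0.139 -> 0.039s away (accepted); delta=0.141 -> 0.041s away (rejected).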
+ online_dataset, _ = make_new_buffer() + if online_dataset_size > 0: + online_dataset.add_data( + make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2) + ) + + weights = compute_sampler_weights( + offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio + ) + if offline_dataset_size == 0 or online_dataset_size == 0: + expected_weights = torch.ones(offline_dataset_size + online_dataset_size) + elif online_sampling_ratio == 0: + expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]) + elif online_sampling_ratio == 1: + expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) + expected_weights /= expected_weights.sum() + assert torch.allclose(weights, expected_weights) + + +def test_compute_sampler_weights_nontrivial_ratio(): + # Arbitrarily set small dataset sizes, making sure to have uneven sizes. + # Create spoof offline dataset. + offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))})) + offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) + offline_dataset.episode_data_index = { + "from": torch.tensor([0, 2]), + "to": torch.tensor([2, 4]), + } + # Create spoof online datset. + online_dataset, _ = make_new_buffer() + online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) + online_sampling_ratio = 0.8 + weights = compute_sampler_weights( + offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio + ) + assert torch.allclose( + weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) + ) + + +def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(): + # Arbitrarily set small dataset sizes, making sure to have uneven sizes. + # Create spoof offline dataset. + offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))})) + offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) + offline_dataset.episode_data_index = { + "from": torch.tensor([0]), + "to": torch.tensor([4]), + } + # Create spoof online datset. 
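The expected weights in test_compute_sampler_weights_nontrivial_ratio follow directly from the ratio: the 4 offline frames share probability mass 1 - 0.8 = 0.2 (0.05 each) and the 8 online frames share 0.8 (0.1 each). A few lines reproduce that arithmetic under the same assumption of uniform spreading within each split:

import torch

n_offline, n_online, online_sampling_ratio = 4, 8, 0.8

offline_weights = torch.full((n_offline,), (1 - online_sampling_ratio) / n_offline)  # 0.05 each
online_weights = torch.full((n_online,), online_sampling_ratio / n_online)           # 0.10 each
weights = torch.cat([offline_weights, online_weights])

assert torch.isclose(weights.sum(), torch.tensor(1.0))
assert torch.allclose(
    weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
)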
+ online_dataset, _ = make_new_buffer() + online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) + weights = compute_sampler_weights( + offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1 + ) + assert torch.allclose( + weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]) + ) + + +def test_compute_sampler_weights_drop_n_last_frames(): + """Note: test copied from test_sampler.""" + data_dict = { + "timestamp": [0, 0.1], + "index": [0, 1], + "episode_index": [0, 0], + "frame_index": [0, 1], + } + offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict)) + offline_dataset.hf_dataset.set_transform(hf_transform_to_torch) + offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])} + + online_dataset, _ = make_new_buffer() + online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) + + weights = compute_sampler_weights( + offline_dataset, + offline_drop_n_last_frames=1, + online_dataset=online_dataset, + online_sampling_ratio=0.5, + online_drop_n_last_frames=1, + ) + assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) diff --git a/tests/test_policies.py b/tests/test_policies.py index d9b946ab..d90f0071 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -357,7 +357,8 @@ def test_normalize(insert_temporal_dim): # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override # to test with `policy.use_mpc=false`. - ("xarm", "tdmpc", ["policy.use_mpc=false"], ""), + ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"), + # ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"), ( "pusht", "diffusion",
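Looking back at test_compute_sampler_weights_drop_n_last_frames above, its expected vector decomposes the same way once the dropped frames are zeroed out: the offline dataset has 2 frames with the last one dropped, so its entire 0.5 share lands on frame 0, while the online buffer has 4 episodes of 2 frames with the last frame of each episode dropped, leaving 4 frames that split the other 0.5 into 0.125 each. A quick check of that bookkeeping; this is a toy reconstruction, not the library function:

import torch

def split_mass(n_frames: int, episode_len: int, drop_last: int, total_mass: float) -> torch.Tensor:
    # Zero out the last `drop_last` frames of every episode, spread `total_mass` over the rest.
    keep = torch.tensor([(i % episode_len) < (episode_len - drop_last) for i in range(n_frames)])
    weights = keep.float()
    return weights * (total_mass / weights.sum())

offline = split_mass(n_frames=2, episode_len=2, drop_last=1, total_mass=0.5)  # [0.5, 0.0]
online = split_mass(n_frames=8, episode_len=2, drop_last=1, total_mass=0.5)   # [0.125, 0.0] repeated
expected = torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
assert torch.allclose(torch.cat([offline, online]), expected)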