Compare commits: 7e0c7b07fb ... b88fa1d666

12 Commits:

- b88fa1d666
- b43ece8934
- c10c5a0e64
- a8db91c40e
- 0f5f7ac780
- ef8579eacd
- 36ee3b50b6
- 86e75ab7f8
- 5d7a0ce32e
- 334d9e92bd
- e82b4c9460
- d4e7b355a9
@@ -116,7 +116,7 @@ pip install -e .
 ```
 
 > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
-`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
+`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
 
 For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
 - [aloha](https://github.com/huggingface/gym-aloha)
@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     tcpdump sysstat screen tmux \
     libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
     speech-dispatcher portaudio19-dev libgeos-dev \
-    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
+    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \
     && apt-get clean && rm -rf /var/lib/apt/lists/*
 
 # Install ffmpeg build dependencies. See:
@@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly
 
 ## The training script
 
-LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
+LeRobot offers a training script at [`lerobot/scripts/train.py`](../lerobot/scripts/train.py). At a high level it does the following:
 
 - Initialize/load a configuration for the following steps using.
 - Instantiates a dataset.
@@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig
 def train(cfg: TrainPipelineConfig):
 ```
 
-You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
+You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
 
 When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
 
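The `@parser.wrap()` / Draccus mechanism described in the paragraph above can be sketched standalone (illustrative config fields, assuming `draccus` is installed; this is not code from the diff):

```python
from dataclasses import dataclass

import draccus


@dataclass
class TrainConfig:
    batch_size: int = 8
    steps: int = 100_000


@draccus.wrap()  # builds a TrainConfig from config files and/or CLI overrides
def train(cfg: TrainConfig):
    print(cfg.batch_size, cfg.steps)


if __name__ == "__main__":
    train()  # e.g. `python example.py --batch_size=32 --steps=5000`
```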
@@ -50,7 +50,7 @@ By default, every field takes its default value specified in the dataclass. If a
 
 ## Specifying values from the CLI
 
-Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
+Let's say that we want to train [Diffusion Policy](../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
 ```bash
 python lerobot/scripts/train.py \
     --dataset.repo_id=lerobot/pusht \
@@ -60,10 +60,10 @@ python lerobot/scripts/train.py \
 
 Let's break this down:
 - To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
-- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
-- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
+- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../lerobot/common/policies)
+- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../lerobot/common/envs/configs.py)
 
-Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
+Let's see another example. Let's say you've been training [ACT](../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
 ```bash
 python lerobot/scripts/train.py \
     --policy.type=act \
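The `.type` selection described above relies on draccus choice types; here is a minimal standalone sketch (illustrative class names, not the actual LeRobot configs):

```python
from dataclasses import dataclass

import draccus


@dataclass
class PolicyConfig(draccus.ChoiceRegistry):
    pass


@PolicyConfig.register_subclass("diffusion")
@dataclass
class DiffusionCfg(PolicyConfig):
    horizon: int = 16


# With a field `policy: PolicyConfig` in the top-level config, passing
# `--policy.type=diffusion` on the command line selects and builds DiffusionCfg.
```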
@@ -74,7 +74,7 @@ python lerobot/scripts/train.py \
 > Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
 
 We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
-Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
+Looking at the [`AlohaEnv`](../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
 ```bash
 python lerobot/scripts/train.py \
     --policy.type=act \
@@ -0,0 +1,212 @@
```python
#!/usr/bin/env python

# Copyright 2025 Nur Muhammad Mahi Shafiullah,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode


@PreTrainedConfig.register_subclass("ditflow")
@dataclass
class DiTFlowConfig(PreTrainedConfig):
    """Configuration class for DiTFlowPolicy.

    Defaults are configured for training with PushT providing proprioceptive and single camera observations.

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

    Notes on the inputs and outputs:
        - "observation.state" is required as an input key.
        - Either:
            - At least one key starting with "observation.image" is required as an input.
              AND/OR
            - The key "observation.environment_state" is required as input.
        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
          views. Right now we only support all images having the same shape.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        horizon: DiT-flow model action prediction size as detailed in `DiTFlowPolicy.select_action`.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            See `DiTFlowPolicy.select_action` for more details.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
            the input data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
            include batch dimension or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
            the output data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
            [-1, 1] range.
        output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
            the original scale. Note that this is also used for normalizing the training targets.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
            within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
            mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
        use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.

        frequency_embedding_dim: The embedding dimension for the time value embedding in the flow model.
        num_blocks: The number of transformer blocks in the DiT flow model.
        hidden_dim: The hidden dimension for the transformer blocks in the DiT flow model.
        num_heads: The number of attention heads in the transformer blocks.
        dropout: The dropout rate used inside the transformer blocks.
        dim_feedforward: The expanded feedforward dimension in the MLPs used in the transformer block.
        activation: The activation function used in the transformer blocks.
        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
            denoising step at inference time. WARNING: you will need to make sure your action-space is
            normalized to fit within this range.
        clip_sample_range: The magnitude of the clipping range as described above.
        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
            spaced).
        do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
            `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
            to False as the original Diffusion Policy implementation does the same.
    """

    # Inputs / output structure.
    n_obs_steps: int = 2
    horizon: int = 16
    n_action_steps: int = 8

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # The original implementation doesn't sample frames for the last 7 steps,
    # which avoids excessive padding and leads to improved training results.
    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1

    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    use_separate_rgb_encoder_per_camera: bool = False

    # Diffusion Transformer (DiT) parameters.
    frequency_embedding_dim: int = 256
    hidden_dim: int = 512
    num_blocks: int = 6
    num_heads: int = 16
    dropout: float = 0.1
    dim_feedforward: int = 4096
    activation: str = "gelu"

    # Noise scheduler.
    training_noise_sampling: str = (
        "uniform"  # "uniform" or "beta", from pi0 https://www.physicalintelligence.company/download/pi0.pdf
    )
    clip_sample: bool = True
    clip_sample_range: float = 1.0

    # Inference
    num_inference_steps: int | None = 100

    # Loss computation
    do_mask_loss_for_padding: bool = False

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-6
    scheduler_name: str = "cosine"
    scheduler_warmup_steps: int = 500

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        super().__post_init__()

        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

        if self.training_noise_sampling not in ("uniform", "beta"):
            raise ValueError(
                f"`training_noise_sampling` must be either 'uniform' or 'beta'. Got {self.training_noise_sampling}."
            )

    def get_optimizer_preset(self) -> AdamConfig:
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
        return DiffuserSchedulerConfig(
            name=self.scheduler_name,
            num_warmup_steps=self.scheduler_warmup_steps,
        )

    def validate_features(self) -> None:
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
                        f"`{key}`."
                    )

        # Check that all input images have the same shape.
        first_image_key, first_image_ft = next(iter(self.image_features.items()))
        for key, image_ft in self.image_features.items():
            if image_ft.shape != first_image_ft.shape:
                raise ValueError(
                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                )

    @property
    def observation_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

    @property
    def reward_delta_indices(self) -> None:
        return None
```
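With the defaults above (`n_obs_steps=2`, `horizon=16`), the delta-index properties resolve as follows; a quick hedged sketch, assuming this file is importable from a LeRobot checkout:

```python
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig

cfg = DiTFlowConfig()
print(cfg.observation_delta_indices)  # [-1, 0]: previous and current observation frames
print(cfg.action_delta_indices)       # [-1, 0, ..., 14]: 16 action frames, starting one step back
print(cfg.get_optimizer_preset())     # AdamConfig built from the optimizer_* preset fields
```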
@@ -0,0 +1,587 @@
```python
# Copyright 2025 Nur Muhammad Mahi Shafiullah,
# and The HuggingFace Inc. team. All rights reserved.
# Heavy inspiration taken from
# * DETR by Meta AI (Carion et. al.): https://github.com/facebookresearch/detr
# * DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT
# * DiT Policy by Dasari et. al. : https://github.com/sudeepdasari/dit-policy

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
from collections import deque

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812

from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionRgbEncoder
from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import (
    get_device_from_parameters,
    get_dtype_from_parameters,
    populate_queues,
)


def _get_activation_fn(activation: str):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return nn.GELU(approximate="tanh")
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)


class _TimeNetwork(nn.Module):
    def __init__(self, frequency_embedding_dim, hidden_dim, learnable_w=False, max_period=1000):
        assert frequency_embedding_dim % 2 == 0, "time_dim must be even!"
        half_dim = int(frequency_embedding_dim // 2)
        super().__init__()

        w = np.log(max_period) / (half_dim - 1)
        w = torch.exp(torch.arange(half_dim) * -w).float()
        self.register_parameter("w", nn.Parameter(w, requires_grad=learnable_w))

        self.out_net = nn.Sequential(
            nn.Linear(frequency_embedding_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, t):
        assert len(t.shape) == 1, "assumes 1d input timestep array"
        t = t[:, None] * self.w[None]
        t = torch.cat((torch.cos(t), torch.sin(t)), dim=1)
        return self.out_net(t)


class _ShiftScaleMod(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act = nn.SiLU()
        self.scale = nn.Linear(dim, dim)
        self.shift = nn.Linear(dim, dim)

    def forward(self, x, c):
        c = self.act(c)
        return x * (1 + self.scale(c)[None]) + self.shift(c)[None]

    def reset_parameters(self):
        nn.init.zeros_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)
        nn.init.zeros_(self.scale.bias)
        nn.init.zeros_(self.shift.bias)


class _ZeroScaleMod(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act = nn.SiLU()
        self.scale = nn.Linear(dim, dim)

    def forward(self, x, c):
        c = self.act(c)
        return x * self.scale(c)[None]

    def reset_parameters(self):
        nn.init.zeros_(self.scale.weight)
        nn.init.zeros_(self.scale.bias)


class _DiTDecoder(nn.Module):
    def __init__(self, d_model=256, nhead=6, dim_feedforward=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        # create mlp
        self.mlp = nn.Sequential(
            self.linear1,
            self.activation,
            self.dropout2,
            self.linear2,
            self.dropout3,
        )

        # create modulation layers
        self.attn_modulate = _ShiftScaleMod(d_model)
        self.attn_gate = _ZeroScaleMod(d_model)
        self.mlp_modulate = _ShiftScaleMod(d_model)
        self.mlp_gate = _ZeroScaleMod(d_model)

    def forward(self, x, t, cond):
        # process the conditioning vector first
        cond = cond + t

        x2 = self.attn_modulate(self.norm1(x), cond)
        x2, _ = self.self_attn(x2, x2, x2, need_weights=False)
        x = x + self.attn_gate(self.dropout1(x2), cond)

        x3 = self.mlp_modulate(self.norm2(x), cond)
        x3 = self.mlp(x3)
        x3 = self.mlp_gate(x3, cond)
        return x + x3

    def reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for s in (self.attn_modulate, self.attn_gate, self.mlp_modulate, self.mlp_gate):
            s.reset_parameters()


class _FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_size):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_size, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, t, cond):
        # process the conditioning vector first
        cond = cond + t

        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1)
        x = modulate(x, shift, scale)
        x = self.linear(x)
        return x

    def reset_parameters(self):
        for p in self.parameters():
            nn.init.zeros_(p)


class _TransformerDecoder(nn.Module):
    def __init__(self, base_module, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(base_module) for _ in range(num_layers)])

        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, src, t, cond):
        x = src
        for layer in self.layers:
            x = layer(x, t, cond)
        return x


class _DiTNoiseNet(nn.Module):
    def __init__(
        self,
        ac_dim,
        ac_chunk,
        cond_dim,
        time_dim=256,
        hidden_dim=256,
        num_blocks=6,
        dropout=0.1,
        dim_feedforward=2048,
        nhead=8,
        activation="gelu",
        clip_sample=False,
        clip_sample_range=1.0,
    ):
        super().__init__()
        self.ac_dim, self.ac_chunk = ac_dim, ac_chunk

        # positional encoding blocks
        self.register_parameter(
            "dec_pos",
            nn.Parameter(torch.empty(ac_chunk, 1, hidden_dim), requires_grad=True),
        )
        nn.init.xavier_uniform_(self.dec_pos.data)

        # input encoder mlps
        self.time_net = _TimeNetwork(time_dim, hidden_dim)
        self.ac_proj = nn.Sequential(
            nn.Linear(ac_dim, ac_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ac_dim, hidden_dim),
        )
        self.cond_proj = nn.Linear(cond_dim, hidden_dim)

        # decoder blocks
        decoder_module = _DiTDecoder(
            hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.decoder = _TransformerDecoder(decoder_module, num_blocks)

        # turns predicted tokens into epsilons
        self.eps_out = _FinalLayer(hidden_dim, ac_dim)

        # clip the output samples
        self.clip_sample = clip_sample
        self.clip_sample_range = clip_sample_range

        print("Number of flow params: {:.2f}M".format(sum(p.numel() for p in self.parameters()) / 1e6))

    def forward(self, noisy_actions, time, global_cond):
        c = self.cond_proj(global_cond)
        time_enc = self.time_net(time)

        ac_tokens = self.ac_proj(noisy_actions)  # [B, T, adim] -> [B, T, hidden_dim]
        ac_tokens = ac_tokens.transpose(0, 1)  # [B, T, hidden_dim] -> [T, B, hidden_dim]

        # Allow variable length action chunks
        dec_in = ac_tokens + self.dec_pos[: ac_tokens.size(0)]  # [T, B, hidden_dim]

        # apply decoder
        dec_out = self.decoder(dec_in, time_enc, c)

        # apply final epsilon prediction layer
        eps_out = self.eps_out(dec_out, time_enc, c)  # [T, B, hidden_dim] -> [T, B, adim]
        return eps_out.transpose(0, 1)  # [T, B, adim] -> [B, T, adim]

    @torch.no_grad()
    def sample(
        self, condition: torch.Tensor, timesteps: int = 100, generator: torch.Generator | None = None
    ) -> torch.Tensor:
        # Use Euler integration to solve the ODE.
        batch_size, device = condition.shape[0], condition.device
        x_0 = self.sample_noise(batch_size, device, generator)
        dt = 1.0 / timesteps
        t_all = (
            torch.arange(timesteps, device=device).float().unsqueeze(0).expand(batch_size, timesteps)
            / timesteps
        )

        for k in range(timesteps):
            t = t_all[:, k]
            x_0 = x_0 + dt * self.forward(x_0, t, condition)
            if self.clip_sample:
                x_0 = torch.clamp(x_0, -self.clip_sample_range, self.clip_sample_range)
        return x_0

    def sample_noise(self, batch_size: int, device, generator: torch.Generator | None = None) -> torch.Tensor:
        return torch.randn(batch_size, self.ac_chunk, self.ac_dim, device=device, generator=generator)
```
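The `sample` method above is a plain Euler integrator for the learned velocity field (the code reuses the name `x_0` for the running sample). Written out, the update it performs is:

```latex
% Euler integration of the velocity field v_\theta from t = 0 (noise) to t = 1 (data),
% with conditioning c and step size \Delta t = 1 / \text{timesteps}:
x_{t + \Delta t} = x_t + \Delta t \, v_\theta(x_t, t, c)
```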
```python
class DiTFlowPolicy(PreTrainedPolicy):
    """
    DiT flow-matching policy, adapted from Diffusion Policy as per "Diffusion Policy: Visuomotor
    Policy Learning via Action Diffusion"
    (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
    """

    config_class = DiTFlowConfig
    name = "DiTFlow"

    def __init__(
        self,
        config: DiTFlowConfig,
        dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

        self.dit_flow = DiTFlowModel(config)

        self.reset()

    def get_optim_params(self) -> dict:
        return self.dit_flow.parameters()

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`"""
        self._queues = {
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.n_action_steps),
        }
        if self.config.image_features:
            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
        if self.config.env_state_feature:
            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad
    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Select a single action given environment observations.

        This method handles caching a history of observations and an action trajectory generated by the
        underlying flow model. Here's how it works:
          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
            copied `n_obs_steps` times to fill the cache).
          - The flow model generates `horizon` steps worth of actions.
          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
        Schematically this looks like:
            ----------------------------------------------------------------------------------------------
            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
            ----------------------------------------------------------------------------------------------
        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
        "horizon" may not be the best name to describe what the variable actually means, because this period
        is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.dit_flow.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]

            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = torch.stack(
                [batch[key] for key in self.config.image_features], dim=-4
            )
        batch = self.normalize_targets(batch)
        loss = self.dit_flow.compute_loss(batch)
        return {"loss": loss}
```
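A minimal sketch of how this queue-based API is typically driven from a rollout loop; the `env` and `stats` objects here are hypothetical stand-ins, not part of this diff:

```python
import torch

# Assumes DiTFlowConfig / DiTFlowPolicy from this diff are importable, that `stats`
# holds dataset statistics, and that `env` yields dicts of "observation.*" tensors.
policy = DiTFlowPolicy(DiTFlowConfig(), dataset_stats=stats)
policy.eval()
policy.reset()  # clear observation/action queues at episode start

obs, done = env.reset(), False
while not done:
    with torch.inference_mode():
        action = policy.select_action(obs)  # pops one step; regenerates a chunk when empty
    obs, reward, done = env.step(action)
```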
```python
class DiTFlowModel(nn.Module):
    def __init__(self, config: DiTFlowConfig):
        super().__init__()
        self.config = config

        # Build observation encoders (depending on which observations are provided).
        global_cond_dim = self.config.robot_state_feature.shape[0]
        if self.config.image_features:
            num_images = len(self.config.image_features)
            if self.config.use_separate_rgb_encoder_per_camera:
                encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
                self.rgb_encoder = nn.ModuleList(encoders)
                global_cond_dim += encoders[0].feature_dim * num_images
            else:
                self.rgb_encoder = DiffusionRgbEncoder(config)
                global_cond_dim += self.rgb_encoder.feature_dim * num_images
        if self.config.env_state_feature:
            global_cond_dim += self.config.env_state_feature.shape[0]

        self.velocity_net = _DiTNoiseNet(
            ac_dim=config.action_feature.shape[0],
            ac_chunk=config.horizon,
            cond_dim=global_cond_dim * config.n_obs_steps,
            time_dim=config.frequency_embedding_dim,
            hidden_dim=config.hidden_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
            dim_feedforward=config.dim_feedforward,
            nhead=config.num_heads,
            activation=config.activation,
            clip_sample=config.clip_sample,
            clip_sample_range=config.clip_sample_range,
        )

        self.num_inference_steps = config.num_inference_steps or 100
        self.training_noise_sampling = config.training_noise_sampling
        if config.training_noise_sampling == "uniform":
            self.noise_distribution = torch.distributions.Uniform(
                low=0,
                high=1,
            )
        elif config.training_noise_sampling == "beta":
            # From the Pi0 paper, https://www.physicalintelligence.company/download/pi0.pdf Appendix B.
            # There, they say the PDF for the distribution they use is the following:
            # $p(t) = Beta((s-t) / s; 1.5, 1)$
            # So, we first figure out the distribution over $t'$ and then transform it to $t = s - s * t'$.
            s = 0.999  # constant from the paper
            beta_dist = torch.distributions.Beta(
                concentration1=1.5,  # alpha
                concentration0=1.0,  # beta
            )
            affine_transform = torch.distributions.transforms.AffineTransform(loc=s, scale=-s)
            self.noise_distribution = torch.distributions.TransformedDistribution(
                beta_dist, [affine_transform]
            )

    # ========= inference ============
    def conditional_sample(
        self,
        batch_size: int,
        global_cond: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        device = get_device_from_parameters(self)
        dtype = get_dtype_from_parameters(self)

        # Expand global conditioning to the batch size.
        if global_cond is not None:
            global_cond = global_cond.expand(batch_size, -1).to(device=device, dtype=dtype)

        # Sample prior.
        sample = self.velocity_net.sample(
            global_cond, timesteps=self.num_inference_steps, generator=generator
        )
        return sample

    def _prepare_global_conditioning(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode image features and concatenate them all together along with the state vector."""
        batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
        global_cond_feats = [batch[OBS_ROBOT]]
        # Extract image features.
        if self.config.image_features:
            if self.config.use_separate_rgb_encoder_per_camera:
                # Combine batch and sequence dims while rearranging to make the camera index dimension first.
                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
                img_features_list = torch.cat(
                    [
                        encoder(images)
                        for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
                    ]
                )
                # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
                img_features = einops.rearrange(
                    img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
                )
            else:
                # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
                img_features = self.rgb_encoder(
                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
                )
                # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
                # feature dim (effectively concatenating the camera features).
                img_features = einops.rearrange(
                    img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
                )
            global_cond_feats.append(img_features)

        if self.config.env_state_feature:
            global_cond_feats.append(batch[OBS_ENV])

        # Concatenate features then flatten to (B, global_cond_dim).
        return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)

    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        This function expects `batch` to have:
        {
            "observation.state": (B, n_obs_steps, state_dim)

            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
                AND/OR
            "observation.environment_state": (B, environment_dim)
        }
        """
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Encode image features and concatenate them all together along with the state vector.
        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # run sampling
        actions = self.conditional_sample(batch_size, global_cond=global_cond)

        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
        actions = actions[:, start:end]

        return actions

    def compute_loss(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)

            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
                AND/OR
            "observation.environment_state": (B, environment_dim)

            "action": (B, horizon, action_dim)
            "action_is_pad": (B, horizon)
        }
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
        assert "observation.images" in batch or "observation.environment_state" in batch
        n_obs_steps = batch["observation.state"].shape[1]
        horizon = batch["action"].shape[1]
        assert horizon == self.config.horizon
        assert n_obs_steps == self.config.n_obs_steps

        # Encode image features and concatenate them all together along with the state vector.
        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # Forward diffusion.
        trajectory = batch["action"]
        # Sample noise to add to the trajectory.
        noise = self.velocity_net.sample_noise(trajectory.shape[0], trajectory.device)
        # Sample a random noising timestep for each item in the batch.
        timesteps = self.noise_distribution.sample((trajectory.shape[0],)).to(trajectory.device)
        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
        noisy_trajectory = (1 - timesteps[:, None, None]) * noise + timesteps[:, None, None] * trajectory

        # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
        pred = self.velocity_net(noisy_actions=noisy_trajectory, time=timesteps, global_cond=global_cond)
        target = trajectory - noise
        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
        if self.config.do_mask_loss_for_padding:
            if "action_is_pad" not in batch:
                raise ValueError(
                    "You need to provide 'action_is_pad' in the batch when "
                    f"{self.config.do_mask_loss_for_padding=}."
                )
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound.unsqueeze(-1)

        return loss.mean()
```
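`compute_loss` above implements the standard flow-matching (rectified flow) objective; despite the inherited "Forward diffusion" comment, the interpolation is linear and the network regresses the constant path velocity:

```latex
% With noise x_0 ~ N(0, I), data trajectory x_1, conditioning c,
% and t ~ p(t) ("uniform" or the pi0-style "beta" distribution):
x_t = (1 - t)\,x_0 + t\,x_1, \qquad
\mathcal{L}(\theta) = \mathbb{E}_{x_0,\,x_1,\,t}\!\left[\, \lVert v_\theta(x_t, t, c) - (x_1 - x_0) \rVert_2^2 \,\right]
```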
@@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig
 from lerobot.common.envs.utils import env_to_policy_features
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.dit_flow.configuration_dit_flow import DiTFlowConfig
 from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy

@@ -43,6 +44,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
         return DiffusionPolicy
+    elif name == "ditflow":
+        from lerobot.common.policies.dit_flow.modeling_dit_flow import DiTFlowPolicy
+
+        return DiTFlowPolicy
     elif name == "act":
         from lerobot.common.policies.act.modeling_act import ACTPolicy

@@ -68,6 +73,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return TDMPCConfig(**kwargs)
     elif policy_type == "diffusion":
         return DiffusionConfig(**kwargs)
+    elif policy_type == "ditflow":
+        return DiTFlowConfig(**kwargs)
     elif policy_type == "act":
         return ACTConfig(**kwargs)
     elif policy_type == "vqbet":
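With these three registrations, the new policy becomes reachable through the same factory path as the existing ones, and hence via `--policy.type=ditflow` on the training CLI shown in the tutorial hunks above. A quick hedged sketch, assuming the package layout in this diff:

```python
from lerobot.common.policies.factory import get_policy_class, make_policy_config

cfg = make_policy_config("ditflow", n_action_steps=8)  # builds a DiTFlowConfig
policy_cls = get_policy_class("ditflow")  # imports and returns DiTFlowPolicy lazily
print(type(cfg).__name__, policy_cls.__name__)
```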
@@ -512,13 +512,13 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--width",
-        type=str,
+        type=int,
         default=640,
         help="Set the width for all cameras. If not provided, use the default width of each camera.",
     )
     parser.add_argument(
         "--height",
-        type=str,
+        type=int,
         default=480,
         help="Set the height for all cameras. If not provided, use the default height of each camera.",
     )

@@ -492,13 +492,13 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--width",
-        type=str,
+        type=int,
         default=None,
         help="Set the width for all cameras. If not provided, use the default width of each camera.",
     )
     parser.add_argument(
         "--height",
-        type=str,
+        type=int,
         default=None,
         help="Set the height for all cameras. If not provided, use the default height of each camera.",
     )
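The change from `type=str` to `type=int` matters because argparse stores the flag verbatim as text when `type=str`, so later arithmetic or comparisons on width/height would operate on strings. A standalone illustration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--width", type=int, default=640)  # was type=str before this change

args = parser.parse_args(["--width", "1280"])
assert isinstance(args.width, int)  # with type=str this would be the string "1280"
print(args.width // 2)  # numeric use now works: 640
```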
@@ -174,7 +174,10 @@ def run_server(
             dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
         ]
         videos_info = [
-            {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
+            {
+                "url": url_for("static", filename=str(video_path).replace("\\", "/")),
+                "filename": video_path.parent.name,
+            }
             for video_path in video_paths
         ]
         tasks = dataset.meta.episodes[episode_id]["tasks"]

@@ -381,7 +384,7 @@ def visualize_dataset_html(
     if isinstance(dataset, LeRobotDataset):
         ln_videos_dir = static_dir / "videos"
         if not ln_videos_dir.exists():
-            ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
+            ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix())
 
     if serve:
         run_server(dataset, episodes, host, port, static_dir, template_dir)
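Both hunks address the same Windows issue: `pathlib` stringifies paths with backslashes there, which are not valid URL separators for Flask's `url_for`. A standalone illustration with a hypothetical filename:

```python
from pathlib import PureWindowsPath

p = PureWindowsPath("videos") / "episode_000000.mp4"
print(str(p))                     # videos\episode_000000.mp4  (backslash leaks into the URL)
print(str(p).replace("\\", "/"))  # videos/episode_000000.mp4  (what run_server now does)
print(p.as_posix())               # videos/episode_000000.mp4  (same idea as the symlink fix)
```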