Compare commits: e110c1b798...ab2bb51199 (13 commits)

Commits: ab2bb51199, b43ece8934, c10c5a0e64, a8db91c40e, 0f5f7ac780, b1c1d395c1, a0510c0f5e, 31788f65dd, 41ebc1bfb3, 9726f20661, ea81279964, ffa5d9e96e, 5f06541060
@@ -116,7 +116,7 @@ pip install -e .
```

> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)

For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
- [aloha](https://github.com/huggingface/gym-aloha)
@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
    tcpdump sysstat screen tmux \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
    speech-dispatcher portaudio19-dev libgeos-dev \
    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
    python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

# Install ffmpeg build dependencies. See:
@@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly

## The training script

LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
LeRobot offers a training script at [`lerobot/scripts/train.py`](../lerobot/scripts/train.py). At a high level it does the following:

- Initialize/load a configuration for the following steps using.
- Instantiates a dataset.
@@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig
def train(cfg: TrainPipelineConfig):
```

You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)

When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
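
The hunk above describes the `@parser.wrap()` / Draccus mechanism. As a rough illustration only (the config classes below are invented for this sketch and are not LeRobot's actual `TrainPipelineConfig`), the same pattern can be reproduced with Draccus directly:

```python
# Minimal sketch of the dataclass-driven CLI pattern described above, using
# Draccus directly. ExampleDatasetConfig/ExampleTrainConfig are illustrative
# stand-ins, not LeRobot classes.
from dataclasses import dataclass, field

import draccus


@dataclass
class ExampleDatasetConfig:
    repo_id: str = "lerobot/pusht"


@dataclass
class ExampleTrainConfig:
    dataset: ExampleDatasetConfig = field(default_factory=ExampleDatasetConfig)
    steps: int = 100_000
    batch_size: int = 8


@draccus.wrap()
def main(cfg: ExampleTrainConfig):
    # Defaults come from the dataclass definitions; CLI flags such as
    # `--dataset.repo_id=lerobot/aloha_sim_insertion_human --steps=2000`
    # override them, just like the overrides shown in this tutorial.
    print(cfg)


if __name__ == "__main__":
    main()
```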
@@ -50,7 +50,7 @@ By default, every field takes its default value specified in the dataclass. If a

## Specifying values from the CLI

Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
Let's say that we want to train [Diffusion Policy](../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
```bash
python lerobot/scripts/train.py \
    --dataset.repo_id=lerobot/pusht \
@@ -60,10 +60,10 @@ python lerobot/scripts/train.py \

Let's break this down:
- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../lerobot/common/policies)
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../lerobot/common/envs/configs.py)

Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
Let's see another example. Let's say you've been training [ACT](../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
```bash
python lerobot/scripts/train.py \
    --policy.type=act \
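
The `.type` bullets in the hunk above rely on `draccus.ChoiceRegistry`. Here is a self-contained sketch of that mechanism; the toy config classes are invented for illustration and are not LeRobot's real policy configs:

```python
# Toy illustration of the ChoiceRegistry mechanism behind `--policy.type`.
from dataclasses import dataclass

import draccus


@dataclass
class ToyPolicyConfig(draccus.ChoiceRegistry):
    @property
    def type(self) -> str:
        # Name under which the concrete subclass was registered.
        return self.get_choice_name(self.__class__)


@ToyPolicyConfig.register_subclass("act")
@dataclass
class ToyACTConfig(ToyPolicyConfig):
    chunk_size: int = 100


@ToyPolicyConfig.register_subclass("diffusion")
@dataclass
class ToyDiffusionConfig(ToyPolicyConfig):
    num_inference_steps: int = 10


# Passing `--policy.type=diffusion` on the CLI makes Draccus pick the subclass
# registered under that name; instantiating it directly shows the same mapping:
cfg = ToyDiffusionConfig()
print(cfg.type)  # "diffusion" -- the name used with `--policy.type=...`
```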
@@ -74,7 +74,7 @@ python lerobot/scripts/train.py \
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.

We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
Looking at the [`AlohaEnv`](../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
```bash
python lerobot/scripts/train.py \
    --policy.type=act \
@@ -111,6 +111,32 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
        return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
    """Used by DexVLA to train Stage2"""

    num_warmup_steps: int

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        def lr_lambda(current_step):
            def linear_warmup_schedule(current_step):
                if current_step <= 0:
                    return 1 / (self.num_warmup_steps + 1)
                frac = 1 - current_step / self.num_warmup_steps
                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1

            def constant_schedule(current_step):
                return 1

            if current_step < self.num_warmup_steps:
                return linear_warmup_schedule(current_step)

            return constant_schedule(current_step)

        return LambdaLR(optimizer, lr_lambda, -1)


def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
    state_dict = scheduler.state_dict()
    write_json(state_dict, save_dir / SCHEDULER_STATE)
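
To see what the new `constant_with_warmup` scheduler does in practice, here is a small usage sketch (assuming LeRobot and PyTorch are installed; the numbers are arbitrary):

```python
# Build the scheduler added above and inspect the learning rate around the
# warmup boundary: it ramps up linearly for `num_warmup_steps`, then stays flat.
import torch

from lerobot.common.optim.schedulers import ConstantWithWarmupSchedulerConfig

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=2e-5)

cfg = ConstantWithWarmupSchedulerConfig(num_warmup_steps=100)
scheduler = cfg.build(optimizer, num_training_steps=10_000)

for step in range(300):
    optimizer.step()
    scheduler.step()
    if step in (0, 50, 99, 100, 299):
        print(step, scheduler.get_last_lr())  # approaches, then holds, 2e-5
```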
@@ -13,6 +13,7 @@
# limitations under the License.

from .act.configuration_act import ACTConfig as ACTConfig
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -0,0 +1,140 @@
<h1 align="center">
DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Policy Learning</h1>

This policy is community contributed. For more information about DexVLA, you can also refer to [the original repository](https://github.com/juruobenruo/DexVLA).
See the [project website](https://dex-vla.github.io/).

## Dataset
### Data format
DexVLA takes RGB images, language instructions and states as input. For our setting, we use three camera views, namely a top camera and two wrist cameras.

⭐ A major difference between DexVLA and other VLAs is that DexVLA takes in raw language and outputs sub-step reasoning based on the current observations.
So you have to <font color='red'>add sub-step reasoning to your data for training</font>.

Specifically, your data should include a key ``reasoning``, which is a list of sub-step reasoning strings corresponding to each observation.
For example, if the episode has 10 steps, the length of this list should be 10 as well, and it may look like:
~~~python
reasoning = [
    "This is step 1.",
    "This is step 1.",
    "This is step 2.",
    "This is step 2.",
    ...
    "This is step 4.",
]
~~~

Besides, your data should include another key ``action_is_pad``, which is a boolean mask indicating which positions of each action chunk are padding.
Suppose the action chunk size is 5 and the episode length is 10. Then the chunks starting at the last 4 steps must be padded so that each chunk still has length 5.
The masks look like:
~~~
The 6th chunk: [false, false, false, false, true]
The 7th chunk: [false, false, false, true, true]
The 8th chunk: [false, false, true, true, true]
The 9th chunk: [false, true, true, true, true]
~~~
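
For reference, here is one way such masks could be generated programmatically (an illustrative sketch, not DexVLA's own data-conversion code):

~~~python
# Build `action_is_pad` masks for an episode of 10 steps with chunk size 5.
episode_len = 10
chunk_size = 5

action_is_pad = []
for t in range(episode_len):
    remaining = episode_len - t  # real actions left from timestep t onward
    action_is_pad.append([i >= remaining for i in range(chunk_size)])

print(action_is_pad[6])  # the "6th chunk" above: [False, False, False, False, True]
~~~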

### Training Data for DexVLA
The pretraining dataset comprises approximately 100 hours of data collected by ourselves. The dataset mainly covers the following embodiments: mobile Agilex Aloha, single-arm Franka Emika, and single-arm UR5e.
We do not use any public dataset such as Open-X or DROID.

## 🤗 Download Pretrained Weights
### Download the official Qwen2-VL weights
We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities
for vision-language tasks. We use the off-the-shelf Qwen2-VL model proposed
in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post-training on the VLM itself. You can download the official weights from this link:

| Model | Link |
|---------------------|----------------------------------------------------------------|
| Qwen2-VL (~2B) | [huggingface](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) |

**❗❗** After downloading the standard weights, you have to replace the official "config.json"
with our ["config.json"](https://github.com/juruobenruo/DexVLA/blob/main/docs/config.json) designed for VLA.
### Download our pretrained ScaleDP-H weights (Stage 1)
We have released the pretrained ScaleDP-H weights obtained after Stage 1. You can download them and directly fine-tune on your own data in Stage 2.

| Model | Link |
|-------------------|----------------------------------------------------------------|
| ScaleDP-H (~1B) | [huggingface](https://huggingface.co/lesjie/scale_dp_h) |
| ScaleDP-L (~400M) | [huggingface](https://huggingface.co/lesjie/scale_dp_l) |

**❗❗** After downloading the weights, you have to convert them into ``safetensors`` format. You can simply run this code:
~~~python
import torch
from safetensors.torch import save_file
path = "/path/to/open_scale_dp_l_backbone.ckpt"
checkpoint = torch.load(path, map_location=torch.device('cpu'))['nets']['nets']

# Save the weights in safetensors format
safetensors_path = "/path/to/open_scale_dp_l_backbone.safetensors"
save_file(checkpoint, safetensors_path)
print(f"Converted {path} to {safetensors_path}")
~~~

## 🦾 Train
We have already provided the pretrained ScaleDP weights from Stage 1. Below we mainly describe the training process for Stage 2 and Stage 3.

### Training Stage 2
~~~shell
python lerobot/scripts/train.py \
  --policy.type dexvla \
  --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
  --policy.pretrained_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \
  --policy.policy_head_size 'scaledp_h' \
  --policy.training_stage 2 \
  --dataset.repo_id lerobot/aloha_mobile_chair \
  --policy.using_film true \
  --output_dir /path/to/output \
  --steps 10000 \
  --save_freq 1000 \
  --optimizer_lr 2e-5
~~~

### Training Stage 3
Stage 3 can be viewed as continual training on specific dexterous tasks such as laundry folding, similar to PI0, so Stage 3 is trained on top of Stage 2.
~~~shell
python lerobot/scripts/train.py \
  --policy.type dexvla \
  --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
  --policy.pretrained_path /path/to/pretrained/stage2/weights \
  --policy.policy_head_size 'scaledp_h' \
  --policy.training_stage 3 \
  --dataset.repo_id lerobot/aloha_mobile_chair \
  --batch_size 2 \
  --policy.using_film true \
  --output_dir /path/to/output \
  --steps 10000 \
  --save_freq 1000 \
  --optimizer_lr 2e-5
~~~

### Training Time
The original DexVLA is trained on 8 x H100 GPUs. The training time for each stage is listed below:

| Stage | Batch Size (per GPU) | Steps | Time (hours) |
|---------|----------------------|--------|--------------|
| Stage 1 | 32 | 60000 | 30 |
| Stage 2 | 12 | 100000 | 30 |
| Stage 3 | 12 | 60000 | 18 |

## Evaluation
### Evaluation Script
You can evaluate DexVLA with the following script.
~~~shell
python lerobot/scripts/eval.py \
  --policy.type dexvla \
  --policy.pretrained_path /path/to/pretrained/stage2/or/stage3/weights \
  --env.type aloha \
  --env.episode_length 5 \
  --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
  --env.task AlohaInsertion-v0 \
  --eval.n_episodes 1 \
  --eval.batch_size 1
~~~

### Inference Speed
Tested on a single A6000 GPU, DexVLA can infer 3.4 action chunks per second. If we execute 25 actions from each chunk, the effective control frequency can reach 85 Hz (3.4 × 25).
@@ -0,0 +1,179 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Qwen2VL model configuration"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
ConstantWithWarmupSchedulerConfig,
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from .policy_heads import register_policy_heads
|
||||
from .qwe2_vla import register_qwen2_vla
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
register_policy_heads()
|
||||
register_qwen2_vla()
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("dexvla")
|
||||
@dataclass
|
||||
class DexVLAConfig(PreTrainedConfig):
|
||||
# For loading policy head
|
||||
policy_head_type: str = "scale_dp_policy"
|
||||
policy_head_size: str = "scaledp_l"
|
||||
action_dim: int = 14
|
||||
state_dim: int = 14
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
n_obs_steps: int = 1
|
||||
|
||||
device: str = "cuda"
|
||||
|
||||
hidden_size: int = 1536
|
||||
qwen2_vl_path: str = (
|
||||
None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
|
||||
)
|
||||
|
||||
pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
|
||||
pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
|
||||
|
||||
training_stage: int = 2 # specific training stage, [2, 3]
|
||||
using_film: bool = True
|
||||
llm_loss_weight: float = 1.0
|
||||
with_llm_head: bool = True
|
||||
using_reasoning: bool = True
|
||||
resize_size: tuple = (240, 320)
|
||||
# Training presets
|
||||
optimizer_lr: float = 2e-5
|
||||
optimizer_betas: Tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
|
||||
scheduler_warmup_steps: int = 2_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
# "VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if self.using_reasoning:
|
||||
assert self.using_film, "using_reasoning requires `using_film=True`"
|
||||
assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`"
|
||||
print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.")
|
||||
else:
|
||||
print(
"Warning: DexVLA recommends using reasoning data, which can better handle long-horizon and dexterous tasks. You can set 'using_reasoning=True'."
)
|
||||
|
||||
if self.qwen2_vl_path is None:
|
||||
raise ValueError(
|
||||
"DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'."
|
||||
)
|
||||
|
||||
if self.policy_head_type == "scale_dp_policy":
|
||||
self.policy_head_config = AutoConfig.for_model(
|
||||
model_type=self.policy_head_type,
|
||||
model_size=self.policy_head_size,
|
||||
cond_dim=self.hidden_size,
|
||||
action_dim=self.action_dim,
|
||||
prediction_horizon=self.chunk_size,
|
||||
state_dim=self.state_dim,
|
||||
)
|
||||
elif self.policy_head_type == "unet_diffusion":
|
||||
self.policy_head_config = AutoConfig.for_model(
|
||||
model_type=self.policy_head_type,
|
||||
global_cond_dim=self.hidden_size,
|
||||
action_dim=self.action_dim,
|
||||
state_dim=self.state_dim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Policy head type {self.policy_head_type} not supported")
|
||||
|
||||
if self.training_stage not in [2, 3]:
|
||||
raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
|
||||
|
||||
self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
# for i in range(self.empty_cameras):
|
||||
# key = f"observation.images.empty_camera_{i}"
|
||||
# empty_camera = PolicyFeature(
|
||||
# type=FeatureType.VISUAL,
|
||||
# shape=(3, 480, 640),
|
||||
# )
|
||||
# self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
if self.training_stage == 3:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
else:
|
||||
return ConstantWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
|
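
As a usage sketch for the configuration above (assuming LeRobot is installed and the official Qwen2-VL-2B-Instruct weights have been downloaded locally; the path below is a placeholder):

```python
# Instantiate the DexVLA config; __post_init__ validates the training stage,
# builds the ScaleDP policy-head config via AutoConfig.for_model(...), and
# loads the Qwen2-VL config from `qwen2_vl_path`.
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig

cfg = DexVLAConfig(
    qwen2_vl_path="/path/to/Qwen2-VL-2B-Instruct",
    policy_head_type="scale_dp_policy",
    policy_head_size="scaledp_h",
    training_stage=2,
    using_film=True,
)
print(cfg.policy_head_config)
```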
@@ -0,0 +1,58 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ActionProjector(nn.Module):
|
||||
def __init__(self, in_dim, out_dim=1024):
|
||||
super().__init__()
|
||||
self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
|
||||
self.mlps = nn.ModuleList(
|
||||
[
|
||||
# nn.LayerNorm(in_dim),
|
||||
nn.Linear(in_dim, in_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.Dropout(0.0),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.global_1d_pool(x.permute(1, 0)).permute(1, 0)
|
||||
for mlp in self.mlps:
|
||||
x = mlp(x)
|
||||
return x
|
||||
|
||||
|
||||
class FiLM(nn.Module):
|
||||
def __init__(self, feature_dim, condition_dim):
|
||||
super().__init__()
|
||||
self.scale_fc = nn.Linear(condition_dim, feature_dim)
|
||||
self.shift_fc = nn.Linear(condition_dim, feature_dim)
|
||||
|
||||
nn.init.zeros_(self.scale_fc.weight)
|
||||
nn.init.zeros_(self.scale_fc.bias)
|
||||
nn.init.zeros_(self.shift_fc.weight)
|
||||
nn.init.zeros_(self.shift_fc.bias)
|
||||
|
||||
def forward(self, x, condition):
|
||||
# calculate scale and shift
|
||||
scale = self.scale_fc(condition)
|
||||
shift = self.shift_fc(condition)
|
||||
|
||||
# film
|
||||
return x * (1 + scale) + shift
|
|
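
A quick sketch of how the FiLM block above modulates features with a condition vector (shapes are illustrative; the import path assumes this file lives at `lerobot/common/policies/dexvla/fusion_modules.py`, as in the upstream DexVLA code):

```python
import torch

from lerobot.common.policies.dexvla.fusion_modules import FiLM

film = FiLM(feature_dim=1536, condition_dim=1536)
features = torch.randn(2, 1536)   # e.g. pooled VLM hidden states
condition = torch.randn(2, 1536)  # e.g. a conditioning embedding

out = film(features, condition)   # computes x * (1 + scale) + shift
print(out.shape)                  # torch.Size([2, 1536])
# Note: the scale/shift layers are zero-initialized, so right after
# construction the output equals the input features.
```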
@@ -0,0 +1,313 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from safetensors.torch import load_file
|
||||
from torch import Tensor
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
||||
|
||||
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
|
||||
from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class DexVLAPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = DexVLAConfig
|
||||
name = "dexvla"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DexVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]:
|
||||
setattr(config.qwen2_vla_config, k, config.__dict__[k])
|
||||
|
||||
# if self.config.training_stage == 2:
|
||||
# self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16)
|
||||
model_base = self.config.qwen2_vl_path
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_base,
|
||||
config=config.qwen2_vla_config,
|
||||
trust_remote_code=True,
|
||||
_fast_init=False,
|
||||
# attn_implementation="flash_attention_2",
|
||||
).to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
if self.config.pretrained_scaledp_path is not None:
|
||||
print(
|
||||
"\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
|
||||
)
|
||||
pretrain_scaledp_weights = load_file(self.config.pretrained_scaledp_path)
|
||||
|
||||
keys_to_del_dit = []
|
||||
pretrain_scaledp_weights = {
|
||||
k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items()
|
||||
}
|
||||
for k in pretrain_scaledp_weights:
|
||||
if "noise_pred" not in k: # del weights of vision backbones
|
||||
keys_to_del_dit.append(k)
|
||||
if "cond_obs_emb" in k:
|
||||
keys_to_del_dit.append(k)
|
||||
for k in keys_to_del_dit:
|
||||
del pretrain_scaledp_weights[k]
|
||||
pretrain_scaledp_weights = {
|
||||
k[15:] if k.startswith("noise_pred_net.") else k: v
|
||||
for k, v in pretrain_scaledp_weights.items()
|
||||
}
|
||||
|
||||
self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False)
|
||||
|
||||
self.model.requires_grad_(False)
|
||||
self.model.policy_head.requires_grad_(True)
|
||||
self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path)
|
||||
self.vla_processor = Qwen2VLAProcess(
|
||||
tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor
|
||||
) # process the input data into VLM format
|
||||
|
||||
self.resize_size = self.config.resize_size
|
||||
ratio = 0.95
|
||||
self.transformations = [
|
||||
transforms.Resize(size=self.resize_size, antialias=True),
|
||||
transforms.RandomCrop(size=[int(self.resize_size[0] * ratio), int(self.resize_size[1] * ratio)]),
|
||||
transforms.Resize(self.resize_size, antialias=True),
|
||||
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
|
||||
transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), # , hue=0.08)
|
||||
]
|
||||
|
||||
self.reset()
|
||||
|
||||
def process_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Apply DexVLA preprocessing to the raw batch: resize images and scale the ranges of actions and states."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
task_descs = batch["task"]
|
||||
try:
|
||||
reasonings = batch["reasoning"]
|
||||
except KeyError:
|
||||
reasonings = ["None."] * len(task_descs)
|
||||
|
||||
pass
|
||||
is_pad = batch["action_is_pad"]
|
||||
all_cam_images = []
|
||||
for k in present_img_keys:
|
||||
all_cam_images.append(batch[k])
|
||||
|
||||
# construct observations, and scale 0-1 to 0-255
|
||||
image_data = torch.stack(all_cam_images) * 255
|
||||
image_data = image_data.to(dtype=torch.uint8)
|
||||
# construct observations
|
||||
qpos_data = batch["observation.state"].float()
|
||||
action_data = batch["action"].float()
|
||||
|
||||
orig_shape = image_data.shape
|
||||
image_data = image_data.view(-1, *orig_shape[2:])
|
||||
|
||||
for transform in self.transformations:
|
||||
image_data = transform(image_data)
|
||||
|
||||
image_data = image_data.view(*orig_shape[:3], *self.resize_size)
|
||||
|
||||
vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings}
|
||||
# processing vl_data into qwen2_vl format
|
||||
vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning)
|
||||
vla_inputs["states"] = qpos_data
|
||||
vla_inputs["is_pad"] = is_pad
|
||||
vla_inputs["actions"] = action_data
|
||||
return vla_inputs
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
processed_batch = self.process_batch(batch)
|
||||
|
||||
ret = self.model.forward(**processed_batch)
|
||||
loss_dict = ret["loss"]
|
||||
loss = loss_dict["loss"].mean()
|
||||
return loss, loss_dict
|
||||
|
||||
def dexvla_predict_action(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
actions=None,
|
||||
states=None,
|
||||
is_pad=None,
|
||||
tokenizer=None,
|
||||
is_eval=True,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
image_grid_spatiotemporal=None,
|
||||
):
|
||||
input_ids = input_ids.to("cuda")
|
||||
with torch.inference_mode():
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
image_grid_spatiotemporal=image_grid_spatiotemporal,
|
||||
is_eval=is_eval,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
temperature=0.2,
|
||||
max_new_tokens=256,
|
||||
eos_token_id=tokenizer.eos_token_id, # End of sequence token
|
||||
pad_token_id=tokenizer.eos_token_id, # Pad token
|
||||
use_cache=True,
|
||||
output_hidden_states=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
output_ids = outputs.sequences
|
||||
# last_hidden_states = outputs.hidden_states[-2][-1]
|
||||
input_token_len = input_ids.shape[1]
|
||||
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
||||
if n_diff_input_output > 0:
|
||||
print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
|
||||
outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0]
|
||||
|
||||
outputs_text = outputs_text.strip()
|
||||
last_hidden_states = [each[-1] for each in outputs.hidden_states] # all hidden states
|
||||
all_hidden_states = torch.cat(last_hidden_states, dim=1)
|
||||
|
||||
action_hidden_states = None
|
||||
labels_input = torch.ones((1, input_token_len)) * -100
|
||||
labels_output = torch.ones((1, output_ids.shape[1] - input_token_len))
|
||||
labels = torch.cat([labels_input, labels_output], dim=1)
|
||||
|
||||
if self.model.using_film:
|
||||
action_hidden_states = self.model.film_forward(
|
||||
labels=labels,
|
||||
input_ids=output_ids,
|
||||
hidden_states=torch.cat(last_hidden_states, dim=1),
|
||||
)
|
||||
|
||||
action = self.model.policy_head(
|
||||
actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad
|
||||
)
|
||||
return action, outputs_text
|
||||
|
||||
def tinyvla_predict_action(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
actions=None,
|
||||
states=None,
|
||||
is_pad=None,
|
||||
is_eval=True,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
image_grid_spatiotemporal=None,
|
||||
):
|
||||
input_ids = input_ids.to("cuda")
|
||||
with torch.inference_mode():
|
||||
all_hidden_states = self.model.forward(
|
||||
input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
image_grid_spatiotemporal=image_grid_spatiotemporal,
|
||||
is_eval=is_eval,
|
||||
tinyvla=True,
|
||||
)
|
||||
|
||||
all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1)
|
||||
|
||||
action = self.model.policy_head(
|
||||
actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad
|
||||
)
|
||||
return action, "tinyvla generates no reasoning"
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
try:
|
||||
task_descs = batch["task"]
|
||||
except KeyError:
|
||||
task_descs = " "
|
||||
print("No task descriptions found for this task")
|
||||
|
||||
all_cam_images = []
|
||||
for k in present_img_keys:
|
||||
all_cam_images.append(batch[k])
|
||||
|
||||
# construct observations, and scale 0-1 to 0-255
|
||||
image_data = torch.stack(all_cam_images) * 255
|
||||
image_data = image_data.to(dtype=torch.uint8)
|
||||
# construct observations
|
||||
qpos_data = batch["observation.state"].float()
|
||||
|
||||
image_data = image_data.squeeze(0)
|
||||
|
||||
for transform in self.transformations:
|
||||
image_data = transform(image_data)
|
||||
|
||||
# processing vl_data into qwen2_vl format
|
||||
vla_inputs = self.vla_processor.single_forward_process(
|
||||
images=image_data, raw_lang=task_descs, reasoning=None, eval=True
|
||||
)
|
||||
vla_inputs["states"] = qpos_data
|
||||
|
||||
if self.config.using_film and self.config.with_llm_head: # dexvla
|
||||
all_actions, outputs = self.dexvla_predict_action(
|
||||
**vla_inputs, is_eval=True, tokenizer=self.tokenizer
|
||||
)
|
||||
else: # tinyvla
|
||||
all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True)
|
||||
|
||||
actions = self.unnormalize_outputs({"action": all_actions})["action"]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
|
@@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
from .configuration_scaledp import ScaleDPPolicyConfig
|
||||
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
|
||||
from .modeling_scaledp import ScaleDP
|
||||
from .modeling_unet_diffusion import ConditionalUnet1D
|
||||
|
||||
|
||||
def register_policy_heads():
|
||||
AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
|
||||
AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
|
||||
AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
|
||||
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
|
|
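
Once `register_policy_heads()` has run, the policy heads are reachable through the standard `Auto*` factories. A hedged usage sketch (dimensions are illustrative; assumes LeRobot's dependencies such as transformers and diffusers are installed, and that this package lives at `lerobot/common/policies/dexvla/policy_heads` as imported by `configuration_dexvla.py`):

```python
from transformers import AutoConfig, AutoModel

from lerobot.common.policies.dexvla.policy_heads import register_policy_heads

register_policy_heads()

# Build a ScaleDP head config by model type, then instantiate the model from it.
head_cfg = AutoConfig.for_model(
    "scale_dp_policy",
    model_size="scaledp_l",
    cond_dim=1536,
    action_dim=14,
    state_dim=14,
    prediction_horizon=50,
)
head = AutoModel.from_config(head_cfg)
print(sum(p.numel() for p in head.parameters()))
```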
@@ -0,0 +1,123 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MODEL_STRUCTURE = {
|
||||
"scaledp_h": {
|
||||
"depth": 32,
|
||||
"n_emb": 1280,
|
||||
"num_heads": 16,
|
||||
},
|
||||
"scaledp_l": {
|
||||
"depth": 24,
|
||||
"n_emb": 1024,
|
||||
"num_heads": 16,
|
||||
}, # 400M
|
||||
}
|
||||
|
||||
|
||||
class ScaleDPPolicyConfig(PretrainedConfig):
|
||||
"""
|
||||
Configuration for ScaleDP policy head
|
||||
"""
|
||||
|
||||
model_type = "scale_dp_policy"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
eval: bool = False,
|
||||
action_dim: int = 14, # action dim
|
||||
# output_dim: int = 14, # action dim
|
||||
cond_dim: int = 1536, # the input dim of the condition
|
||||
state_dim: int = 14, # the input dim of the state
|
||||
prediction_horizon: int = 16, # horizon
|
||||
n_obs_steps: int = 2, # number of observation steps
|
||||
depth: int = 28, # number of DiT blocks
|
||||
n_emb: int = 256, # embedding size
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: int = 4.0,
|
||||
time_as_cond: bool = True,
|
||||
obs_as_cond: bool = True,
|
||||
learn_sigma: bool = False,
|
||||
model_size: str = "none",
|
||||
num_inference_timesteps: int = 10,
|
||||
noise_samples: int = 1,
|
||||
num_train_timesteps: int = 100,
|
||||
**kwargs,
|
||||
):
|
||||
if model_size != "none":
|
||||
depth = MODEL_STRUCTURE[model_size]["depth"]
|
||||
n_emb = MODEL_STRUCTURE[model_size]["n_emb"]
|
||||
num_heads = MODEL_STRUCTURE[model_size]["num_heads"]
|
||||
else:
|
||||
# raise ValueError("model_size should not be 'none'")
|
||||
pass
|
||||
# print("model_size should not be 'none'")
|
||||
self.eval = eval
|
||||
|
||||
self.input_dim = action_dim
|
||||
self.output_dim = action_dim
|
||||
self.prediction_horizon = prediction_horizon
|
||||
|
||||
self.cond_dim = cond_dim
|
||||
self.state_dim = state_dim
|
||||
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.depth = depth
|
||||
self.n_emb = n_emb
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.time_as_cond = time_as_cond
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.learn_sigma = learn_sigma
|
||||
|
||||
self.num_inference_timesteps = num_inference_timesteps
|
||||
self.num_queries = prediction_horizon
|
||||
self.noise_samples = noise_samples
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if loading from a full VLA config (llava_pythia), pull out the action-head sub-config
|
||||
if config_dict.get("model_type") == "llava_pythia":
|
||||
config_dict = config_dict["action_head"]
|
||||
|
||||
if (
|
||||
"model_type" in config_dict
|
||||
and hasattr(cls, "model_type")
|
||||
and config_dict["model_type"] != cls.model_type
|
||||
):
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
|
@@ -0,0 +1,86 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class UnetDiffusionPolicyConfig(PretrainedConfig):
|
||||
"""
|
||||
Configuration for dit diffusion policy head
|
||||
"""
|
||||
|
||||
model_type = "unet_diffusion_policy"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_dim=10,
|
||||
global_cond_dim=2048,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=None,
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
state_dim=7,
|
||||
prediction_horizon=16,
|
||||
noise_samples=1,
|
||||
num_inference_timesteps=10,
|
||||
num_train_timesteps=100,
|
||||
**kwargs,
|
||||
):
|
||||
if down_dims is None:
|
||||
down_dims = [256, 512, 1024]
|
||||
self.input_dim = action_dim
|
||||
self.noise_samples = noise_samples
|
||||
self.prediction_horizon = prediction_horizon
|
||||
self.num_inference_timesteps = num_inference_timesteps
|
||||
self.global_cond_dim = global_cond_dim
|
||||
self.diffusion_step_embed_dim = diffusion_step_embed_dim
|
||||
self.down_dims = down_dims
|
||||
self.kernel_size = kernel_size
|
||||
self.n_groups = n_groups
|
||||
self.state_dim = state_dim
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if loading from a full VLA config (llava_pythia), pull out the action-head sub-config
|
||||
if config_dict.get("model_type") == "llava_pythia":
|
||||
config_dict = config_dict["action_head"]
|
||||
|
||||
if (
|
||||
"model_type" in config_dict
|
||||
and hasattr(cls, "model_type")
|
||||
and config_dict["model_type"] != cls.model_type
|
||||
):
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
|
@@ -0,0 +1,561 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as func
|
||||
import torch.utils.checkpoint
|
||||
from timm.models.vision_transformer import Mlp, use_fused_attn
|
||||
from torch.jit import Final
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from .configuration_scaledp import ScaleDPPolicyConfig
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
fused_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.fused_attn = use_fused_attn()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
|
||||
b, n, c = x.shape
|
||||
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = func.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
# attn = q @ k.transpose(-2, -1)
|
||||
# if attn_mask is not None:
|
||||
# attn += attn_mask
|
||||
# attn = attn.softmax(dim=-1)
|
||||
# attn = self.attn_drop(attn)
|
||||
# x = attn @ v
|
||||
attn_scores = torch.matmul(q, k.transpose(-2, -1))
|
||||
|
||||
# Add attention mask if provided
|
||||
if attn_mask is not None:
|
||||
attn_scores += attn_mask
|
||||
|
||||
# Apply softmax to get attention weights (softmax is applied along the last dimension)
|
||||
attn_weights = func.softmax(attn_scores, dim=-1)
|
||||
|
||||
# Dropout on attention weights (if dropout is used)
|
||||
attn_weights = self.attn_drop(attn_weights)
|
||||
|
||||
# Apply attention weights to value tensor (V)
|
||||
x = torch.matmul(attn_weights, v)
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, n, c)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.bfloat16) / half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding.to(dtype=torch.bfloat16)
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core ScaleDP Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class ScaleDPBlock(nn.Module):
|
||||
"""
|
||||
A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
def approx_gelu():
|
||||
return nn.GELU(approximate="tanh")
|
||||
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
|
||||
6, dim=1
|
||||
)
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn(
|
||||
modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask
|
||||
) # norm, scale&shift, attn, scale,
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of ScaleDP.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, output_dim):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, output_dim, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class ScaleDP(PreTrainedModel):
|
||||
"""
|
||||
Diffusion models with a Transformer backbone.
|
||||
"""
|
||||
|
||||
config_class = ScaleDPPolicyConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ScaleDPPolicyConfig,
|
||||
):
|
||||
super().__init__(config)
|
||||
# compute number of tokens for main trunk and condition encoder
|
||||
if config.n_obs_steps is None:
|
||||
config.n_obs_steps = config.prediction_horizon
|
||||
t = config.prediction_horizon
|
||||
t_cond = 1
|
||||
if not config.time_as_cond:
|
||||
t += 1
|
||||
t_cond -= 1
|
||||
obs_as_cond = config.cond_dim > 0
|
||||
if obs_as_cond:
|
||||
assert config.time_as_cond
|
||||
t_cond += config.n_obs_steps
|
||||
|
||||
# self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
|
||||
self.combine = nn.Sequential(
|
||||
nn.Linear(config.cond_dim + config.state_dim, 1024),
|
||||
nn.ReLU(),
|
||||
nn.Linear(1024, 1024),
|
||||
nn.ReLU(),
|
||||
nn.Linear(1024, config.cond_dim),
|
||||
)
|
||||
self.learn_sigma = config.learn_sigma
|
||||
self.input_dim = config.input_dim
|
||||
self.output_dim = config.output_dim * 2 if config.learn_sigma else config.output_dim
|
||||
self.num_heads = config.num_heads
|
||||
|
||||
self.x_embedder = nn.Linear(config.input_dim, config.n_emb)
|
||||
self.t_embedder = TimestepEmbedder(config.n_emb)
|
||||
self.cond_obs_emb = None
|
||||
if obs_as_cond:
|
||||
self.cond_obs_emb = nn.Linear(config.cond_dim, config.n_emb)
|
||||
|
||||
# Will use fixed sin-cos embedding:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb))
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio)
|
||||
for _ in range(config.depth)
|
||||
]
|
||||
)
|
||||
self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
|
||||
# self.initialize_weights()
|
||||
# constants
|
||||
self.t = t
|
||||
self.t_cond = t_cond
|
||||
self.prediction_horizon = config.prediction_horizon
|
||||
self.time_as_cond = config.time_as_cond
|
||||
self.action_dim = config.output_dim
|
||||
self.obs_as_cond = obs_as_cond
|
||||
logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
# self.proj_to_action = nn.Identity()
|
||||
self.noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=config.num_train_timesteps, # 100
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
set_alpha_to_one=True,
|
||||
steps_offset=0,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
self.num_queries = config.num_queries # 16
|
||||
self.noise_samples = config.noise_samples # 1
|
||||
# self.num_inference_timesteps = config.num_inference_timesteps # 100
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
nn.init.normal_(self.pos_embed, mean=0.0, std=0.02)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.cond_obs_emb.weight, mean=0.0, std=0.02)
|
||||
nn.init.constant_(self.cond_obs_emb.bias, 0)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out adaLN modulation layers in ScaleDP blocks:
|
||||
for block in self.blocks:
|
||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||
"""
|
||||
This long function is unfortunately doing something very simple and is being very defensive:
|
||||
We are separating out all parameters of the models into two buckets: those that will experience
|
||||
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
||||
We are then returning the PyTorch optimizer object.
|
||||
"""
|
||||
|
||||
# separate out all parameters to those that will and won't experience regularizing weight decay
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, Attention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
for mn, m in self.named_modules():
|
||||
for pn, _p in m.named_parameters():
|
||||
fpn = "{}.{}".format(mn, pn) if mn else pn # full param name
|
||||
|
||||
if pn.endswith("bias"):
|
||||
# all biases will not be decayed
|
||||
no_decay.add(fpn)
|
||||
elif pn.startswith("bias"):
|
||||
# MultiheadAttention bias starts with "bias"
|
||||
no_decay.add(fpn)
|
||||
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
|
||||
# weights of whitelist modules will be weight decayed
|
||||
decay.add(fpn)
|
||||
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
|
||||
# weights of blacklist modules will NOT be weight decayed
|
||||
no_decay.add(fpn)
|
||||
|
||||
# validate that we considered every parameter
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
)
|
||||
|
||||
# create the pytorch optimizer object
|
||||
optim_groups = [
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(decay)],
|
||||
"weight_decay": weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
return optim_groups
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.95),
|
||||
):
|
||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
return optimizer
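# Minimal usage sketch (hedged): one way the two-group optimizer could be set up; the
# bare ScaleDPPolicyConfig() call and the hyperparameter values are illustrative assumptions:
#   head = ScaleDP(ScaleDPPolicyConfig())
#   optimizer = head.configure_optimizers(learning_rate=1e-4, weight_decay=1e-3, betas=(0.9, 0.95))
#   # Linear/Attention weights land in the decayed group (weight_decay=1e-3);
#   # biases and LayerNorm/Embedding weights land in the non-decayed group (0.0).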
|
||||
|
||||
def forward(self, actions, hidden_states, states, is_pad):
|
||||
"""
|
||||
Forward pass for the diffusion head.
|
||||
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
|
||||
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b, Tokens, D] 8 1200 1024
|
||||
:param states: robot states, shape [b, D]
|
||||
:return: loss
|
||||
"""
|
||||
if actions is not None: # training time
|
||||
b = actions.size(0)
|
||||
actions = actions[:, : self.num_queries]
|
||||
is_pad = is_pad[:, : self.num_queries]
|
||||
num_noise_samples = self.noise_samples
|
||||
# sample noise to add to actions
|
||||
noise = torch.randn(
|
||||
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
|
||||
) # num_noise, b, Ta, D(1, 2, 16, 14)
|
||||
# sample a diffusion iteration for each data point
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
|
||||
).long()
|
||||
|
||||
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
|
||||
|
||||
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
|
||||
# (this is the forward diffusion process)
|
||||
noisy_actions = torch.cat(
|
||||
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
|
||||
dim=0,
|
||||
) # [num_noise_samples * b, Ta, action_dim]
|
||||
|
||||
noisy_actions = noisy_actions.to(dtype=actions.dtype)
|
||||
assert hidden_states.ndim == 3
|
||||
|
||||
hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
|
||||
timesteps = timesteps.repeat(num_noise_samples)
|
||||
is_pad = is_pad.repeat(num_noise_samples, 1)
|
||||
states = states.repeat(num_noise_samples, 1)
|
||||
|
||||
noise_pred = self.model_forward(
|
||||
noisy_actions, timesteps, global_cond=hidden_states, states=states
|
||||
)
|
||||
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
|
||||
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
|
||||
# loss_dict['loss'] = loss
|
||||
return {"loss": loss}
|
||||
# return loss
|
||||
else: # inference time
|
||||
b = 1
|
||||
tp = self.num_queries
|
||||
action_dim = self.action_dim
|
||||
|
||||
# initialize action from Gaussian noise
|
||||
noisy_action = torch.randn((b, tp, action_dim), device=hidden_states.device)
|
||||
|
||||
naction = noisy_action.to(dtype=hidden_states.dtype)
|
||||
# init scheduler
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||
|
||||
for k in self.noise_scheduler.timesteps:
|
||||
# predict noise
|
||||
noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)
|
||||
|
||||
# inverse diffusion step (remove noise)
|
||||
naction = self.noise_scheduler.step(
|
||||
model_output=noise_pred, timestep=k, sample=naction
|
||||
).prev_sample
|
||||
|
||||
return naction
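# Inference sketch (hedged; `head`, `h`, and `s` are hypothetical names): passing
# actions=None runs the DDIM sampling loop above, so with VLM features `h` of shape
# (1, Tokens, cond_dim) and robot state `s` of shape (1, state_dim):
#   pred_actions = head(actions=None, hidden_states=h, states=s, is_pad=None)
#   # -> (1, num_queries, action_dim), since b is fixed to 1 in this branch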
|
||||
|
||||
def model_forward(self, x, t, global_cond, states):
|
||||
"""
|
||||
Forward pass of ScaleDP.
|
||||
x: (N, T, input_dim) noisy actions
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
global_cond: (N, n_obs_steps, D) tensor of conditions: image embeddings
|
||||
"""
|
||||
global_cond = global_cond.squeeze(1)
|
||||
global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
|
||||
global_cond = self.combine(global_cond)
|
||||
|
||||
if not torch.is_tensor(t):
|
||||
t = torch.tensor([t], dtype=torch.long, device=x.device)
|
||||
elif torch.is_tensor(t) and len(t.shape) == 0:
|
||||
t = t[None].to(x.device)
|
||||
t = t.expand(x.shape[0])  # broadcast a scalar timestep to the batch dimension
|
||||
|
||||
x = self.x_embedder(x) + self.pos_embed.to(
|
||||
device=x.device, dtype=x.dtype
|
||||
) # (N, T, D), where T = prediction_horizon
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
if self.obs_as_cond:
|
||||
global_cond = self.cond_obs_emb(global_cond) # (N, D)
|
||||
# c = t + global_cond.sum(dim=1) # (N, D)
|
||||
c = t + global_cond # (N, D)
|
||||
for block in self.blocks:
|
||||
# x = block(x, c, attn_mask=self.mask) # (N, T, D)
|
||||
x = block(x, c, attn_mask=None) # (N, T, D)
|
||||
x = self.final_layer(x, c) # (N, T, output_dim)
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
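# Worked example (hedged): for a 4x4 token grid with embed_dim=8,
#   pe = get_2d_sincos_pos_embed(embed_dim=8, grid_size=4)   # -> (16, 8)
# half of the channels encode the row coordinate and half the column coordinate,
# each as [sin | cos] pairs at the geometrically spaced frequencies computed above.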
|
||||
|
||||
|
||||
#################################################################################
|
||||
# ScaleDP Configs #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def scaledp_h(**kwargs):
|
||||
return ScaleDP(ScaleDPPolicyConfig(depth=32, n_emb=1280, num_heads=16, **kwargs))  # ScaleDP expects a config object
|
||||
|
||||
|
||||
def scaledp_l(**kwargs):
|
||||
return ScaleDP(ScaleDPPolicyConfig(depth=24, n_emb=1024, num_heads=16, **kwargs))  # ScaleDP expects a config object
|
|
@ -0,0 +1,387 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# requires diffusers==0.11.1
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
|
||||
|
||||
# =================== UNet for Diffusion ==============
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim, dtype):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
]
|
||||
)
|
||||
|
||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||
# predicts per-channel scale and bias
|
||||
cond_channels = out_channels * 2
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
"""
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
|
||||
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:, 0, ...]
|
||||
bias = embed[:, 1, ...]
|
||||
out = scale * out + bias
|
||||
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(PreTrainedModel):
|
||||
_no_split_modules = ["mid_modules", "down_modules", "up_modules"]
|
||||
|
||||
config_class = UnetDiffusionPolicyConfig
|
||||
|
||||
def __init__(self, config: UnetDiffusionPolicyConfig):
|
||||
"""
|
||||
input_dim: Dim of actions.
|
||||
global_cond_dim: Dim of global conditioning applied with FiLM
|
||||
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
|
||||
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
|
||||
down_dims: Channel size for each UNet level.
|
||||
The length of this array determines number of levels.
|
||||
kernel_size: Conv kernel size
|
||||
n_groups: Number of groups for GroupNorm
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
all_dims = [config.input_dim] + list(config.down_dims)
|
||||
start_dim = config.down_dims[0]
|
||||
|
||||
self.num_queries = config.prediction_horizon
|
||||
self.noise_samples = config.noise_samples
|
||||
# self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
|
||||
# self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim)
|
||||
self.norm_after_pool = nn.LayerNorm(config.global_cond_dim)
|
||||
self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim)
|
||||
dsed = config.diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed, torch.bfloat16),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed + config.global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out * 2,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
|
||||
nn.Conv1d(start_dim, config.input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters())))
|
||||
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
# self.proj_to_action = nn.Identity()
|
||||
self.noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=config.num_train_timesteps, # 100
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
set_alpha_to_one=True,
|
||||
steps_offset=0,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
|
||||
# self.num_inference_timesteps = config.num_inference_timesteps # 100
|
||||
|
||||
def forward(self, actions, hidden_states, states, is_pad):
|
||||
"""
|
||||
Forward pass for the diffusion head.
|
||||
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
|
||||
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024
|
||||
:param states: robot states, shape [b, D]
|
||||
:return: loss
|
||||
"""
|
||||
if actions is not None: # training time
|
||||
b = actions.size(0)
|
||||
actions = copy.deepcopy(actions[:, : self.num_queries])
|
||||
is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
|
||||
num_noise_samples = self.noise_samples
|
||||
# sample noise to add to actions
|
||||
noise = torch.randn(
|
||||
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
|
||||
) # num_noise, b, Ta, D
|
||||
# sample a diffusion iteration for each data point
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
|
||||
).long()
|
||||
|
||||
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
|
||||
|
||||
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
|
||||
# (this is the forward diffusion process)
|
||||
noisy_actions = torch.cat(
|
||||
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
|
||||
dim=0,
|
||||
) # [num_noise_samples * b, Ta, action_dim]
|
||||
|
||||
noisy_actions = noisy_actions.to(dtype=actions.dtype)
|
||||
assert hidden_states.ndim == 3
|
||||
|
||||
hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
|
||||
timesteps = timesteps.repeat(num_noise_samples)
|
||||
is_pad = is_pad.repeat(num_noise_samples, 1)
|
||||
states = states.repeat(num_noise_samples, 1)
|
||||
|
||||
noise_pred = self.model_forward(
|
||||
noisy_actions, timesteps, global_cond=hidden_states, states=states
|
||||
)
|
||||
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
|
||||
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
|
||||
# loss_dict['loss'] = loss
|
||||
return {"loss": loss}
|
||||
# return loss
|
||||
else: # inference time
|
||||
b = 1
|
||||
tp = self.num_queries
|
||||
action_dim = self.config.input_dim  # dim of actions (matches final_conv output channels)
|
||||
|
||||
# initialize action from Gaussian noise
|
||||
noisy_action = torch.randn((b, tp, action_dim), device=hidden_states.device)
|
||||
|
||||
naction = noisy_action.to(dtype=hidden_states.dtype)
|
||||
# init scheduler
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||
|
||||
for k in self.noise_scheduler.timesteps:
|
||||
# predict noise
|
||||
noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)
|
||||
|
||||
# inverse diffusion step (remove noise)
|
||||
naction = self.noise_scheduler.step(
|
||||
model_output=noise_pred, timestep=k, sample=naction
|
||||
).prev_sample
|
||||
|
||||
return naction
|
||||
|
||||
def model_forward(
|
||||
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
|
||||
):
|
||||
"""
|
||||
x: (b,T,input_dim)
|
||||
timestep: (b,) or int, diffusion step
|
||||
global_cond: (b,global_cond_dim)
|
||||
output: (b,T,input_dim)
|
||||
"""
|
||||
# (b,t,c)
|
||||
sample = sample.moveaxis(-1, -2)
|
||||
# (b,c,t)
|
||||
# global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
|
||||
global_cond = global_cond.squeeze(1)
|
||||
|
||||
global_cond = self.norm_after_pool(global_cond)
|
||||
global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
|
||||
global_cond = self.combine(global_cond)
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([global_feature, global_cond], axis=-1)
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
# (b,c,t)
|
||||
x = x.moveaxis(-1, -2)
|
||||
# (b,t,c)
|
||||
return x
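# Shape-flow summary: sample (b, T, input_dim) is moved to (b, input_dim, T); each down
# level applies two FiLM-conditioned residual blocks and then halves T while pushing a
# skip onto `h`; the mid blocks act at the bottleneck; each up level pops and concatenates
# the matching skip, refines, and doubles T; final_conv maps channels back to input_dim
# and the result is returned as (b, T, input_dim).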
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python

# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoConfig, AutoModelForCausalLM

from .configuration_qwen2_vla import Qwen2VLAConfig
from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA


def register_qwen2_vla():
    AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
    AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
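# Usage sketch (hedged): after register_qwen2_vla() runs, checkpoints whose config
# declares model_type "qwen2_vla" should resolve through the Auto classes; the repo id
# below is a placeholder, not a checkpoint shipped with this change:
#   register_qwen2_vla()
#   model = AutoModelForCausalLM.from_pretrained("some-org/dexvla-checkpoint")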
|
|
@ -0,0 +1,254 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen2_vl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=32,
|
||||
embed_dim=1280,
|
||||
hidden_size=3584,
|
||||
hidden_act="quick_gelu",
|
||||
mlp_ratio=4,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=14,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.embed_dim = embed_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if config_dict.get("model_type") == "qwen2_vl":
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
if (
|
||||
"model_type" in config_dict
|
||||
and hasattr(cls, "model_type")
|
||||
and config_dict["model_type"] != cls.model_type
|
||||
):
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class Qwen2VLAConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
|
||||
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 152064):
|
||||
Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen2VLModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 8192):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 29568):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 80):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 64):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 80):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
vision_config (`Dict`, *optional*):
|
||||
The config for the visual encoder initialization.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
|
||||
|
||||
>>> # Initializing a Qwen2VL style configuration
|
||||
>>> configuration = Qwen2VLConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
||||
>>> model = Qwen2VLForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_vla"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=8192,
|
||||
intermediate_size=29568,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
attention_dropout=0.0,
|
||||
vision_config=None,
|
||||
rope_scaling=None,
|
||||
# For loading policy head
|
||||
policy_head_type="scale_dp_policy", # unet_diffusion_policy
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = Qwen2VLVisionConfig(**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = Qwen2VLVisionConfig()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
self.policy_head_type = policy_head_type # for loading policy head
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
||||
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self, ignore_keys={"mrope_section"})
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
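# Construction sketch (hedged): the defaults mirror Qwen2-VL-7B, so a config for the
# UNet action head could plausibly be built by overriding only the head selector:
#   config = Qwen2VLAConfig(policy_head_type="unet_diffusion_policy")
#   assert config.model_type == "qwen2_vla"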
|
File diff suppressed because it is too large
|
@ -0,0 +1,172 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from qwen_vl_utils import fetch_image
|
||||
|
||||
|
||||
class Qwen2VLAProcess:
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer=None,
|
||||
max_seq_len=512,
|
||||
multimodal_processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
self.multimodal_processor = multimodal_processor
|
||||
|
||||
def qwen2_image_preprocess(self, each):
|
||||
ele = {}
|
||||
each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
|
||||
ele["image"] = each
|
||||
|
||||
ele["resized_height"] = each.height
|
||||
ele["resized_width"] = each.width
|
||||
each = fetch_image(ele)
|
||||
return torch.from_numpy(np.array(each))
|
||||
|
||||
def single_forward_process(self, images, raw_lang, reasoning, eval=False, use_reasoning=True):
|
||||
len_views = images.shape[0]
|
||||
messages = self.construct_chat_data(len_views, raw_lang)
|
||||
|
||||
data_dict = {"messages": messages}
|
||||
|
||||
image_data = torch.chunk(images, len_views, 0)
|
||||
|
||||
images_list = []
|
||||
|
||||
for _i, each in enumerate(image_data):
|
||||
img_pil = self.qwen2_image_preprocess(each)
|
||||
images_list.append(img_pil)
|
||||
|
||||
text = self.multimodal_processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
model_inputs = self.multimodal_processor(
|
||||
text=text,
|
||||
images=images_list,
|
||||
videos=None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if eval:
|
||||
new_dict = {}
|
||||
for k, v in model_inputs.items():
|
||||
if "image_grid" in k:
|
||||
new_dict["image_grid_spatiotemporal"] = v
|
||||
else:
|
||||
new_dict[k] = v
|
||||
return new_dict
|
||||
|
||||
input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
|
||||
answer = (reasoning + " Next action:" + "<|im_end|>") if use_reasoning else "<|im_end|>"
|
||||
|
||||
output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
|
||||
output_labels = output_text["input_ids"]
|
||||
model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1)
|
||||
model_inputs["attention_mask"] = torch.cat(
|
||||
(model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1
|
||||
)
|
||||
labels = torch.cat((input_labels, output_labels), dim=-1)
|
||||
|
||||
data_dict["labels"] = labels
|
||||
for k, v in model_inputs.items():
|
||||
if "image_grid" in k:
|
||||
data_dict["image_grid_spatiotemporal"] = v
|
||||
else:
|
||||
data_dict[k] = v
|
||||
return data_dict
|
||||
|
||||
def forward(self, batch, use_reasoning=True):
|
||||
"""This is the main process function for processing vl data into Qwen2_vl format"""
|
||||
all_images = batch["images"]
|
||||
all_images = torch.einsum(
|
||||
"v b c h w -> b v c h w", all_images
|
||||
) # camera_views, batch_size, channel, height, width
|
||||
|
||||
ret_l = []
|
||||
|
||||
for idx, images in enumerate(all_images):
|
||||
raw_lang = batch["raw_langs"][idx]
|
||||
reasoning = batch["reasonings"][idx]
|
||||
ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning)
|
||||
ret_l.append(ret_dict)
|
||||
|
||||
return self.post_process(ret_l)
|
||||
|
||||
def post_process(self, instances):
|
||||
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
|
||||
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
|
||||
|
||||
image_grid_spatiotemporal = torch.stack(
|
||||
[instance["image_grid_spatiotemporal"] for instance in instances]
|
||||
)
|
||||
pixel_values = torch.stack([instance["pixel_values"] for instance in instances])
|
||||
pixel_values_videos = None
|
||||
video_grid_spatiotemporal = None
|
||||
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
|
||||
labels = torch.flip(labels, dims=[1])
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
)
|
||||
input_ids = torch.flip(input_ids, dims=[1])
|
||||
b = input_ids.shape[0]
|
||||
|
||||
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(
|
||||
b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
|
||||
)
|
||||
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
|
||||
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask[0],
|
||||
"labels": labels,
|
||||
"image_grid_spatiotemporal": image_grid_spatiotemporal,
|
||||
"pixel_values_videos": pixel_values_videos,
|
||||
"video_grid_spatiotemporal": video_grid_spatiotemporal,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
|
||||
return batch
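# Note: the flip -> pad_sequence -> flip pattern above implements left padding, keeping
# each prompt right-aligned against its appended answer tokens while pad_token_id
# (for input_ids) and -100 (for labels) fill the left side of every row.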
|
||||
|
||||
def construct_chat_data(self, len_image, raw_lang):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [],
|
||||
},
|
||||
]
|
||||
|
||||
for _i in range(len_image):
|
||||
messages[0]["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"image": None,
|
||||
}
|
||||
)
|
||||
messages[0]["content"].append({"type": "text", "text": ""})
|
||||
messages[0]["content"][-1]["text"] = raw_lang
|
||||
|
||||
return messages
|
|
@ -23,6 +23,7 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
|
|||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
|
@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
|||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "dexvla":
|
||||
from lerobot.common.policies.dexvla.modeling_dexvla import DexVLAPolicy
|
||||
|
||||
return DexVLAPolicy
|
||||
elif name == "pi0fast":
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
|
@ -74,6 +79,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "dexvla":
|
||||
return DexVLAConfig(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
else:
|
||||
|
|
|
@ -512,13 +512,13 @@ if __name__ == "__main__":
|
|||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=str,
|
||||
type=int,
|
||||
default=640,
|
||||
help="Set the width for all cameras. If not provided, use the default width of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=str,
|
||||
type=int,
|
||||
default=480,
|
||||
help="Set the height for all cameras. If not provided, use the default height of each camera.",
|
||||
)
|
||||
|
|
|
@ -492,13 +492,13 @@ if __name__ == "__main__":
|
|||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=str,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the width for all cameras. If not provided, use the default width of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=str,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the height for all cameras. If not provided, use the default height of each camera.",
|
||||
)
|
||||
|
|
|
@ -174,7 +174,10 @@ def run_server(
|
|||
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
|
||||
]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
|
||||
{
|
||||
"url": url_for("static", filename=str(video_path).replace("\\", "/")),
|
||||
"filename": video_path.parent.name,
|
||||
}
|
||||
for video_path in video_paths
|
||||
]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
|
@ -381,7 +384,7 @@ def visualize_dataset_html(
|
|||
if isinstance(dataset, LeRobotDataset):
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix())
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
|
|
@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
|||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
dexvla = ["transformers>=4.45.2", "qwen_vl_utils==0.0.10", "timm==0.9.10"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
|
|