jayce-wen 2025-04-10 10:15:55 +08:00 committed by GitHub
commit 5d4d9df58f
17 changed files with 4408 additions and 0 deletions

View File

@ -111,6 +111,32 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by DexVLA to train Stage2"""
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
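# Linearly ramp the LR multiplier from 1 / (num_warmup_steps + 1) at step 0 up to 1.0 at step num_warmup_steps.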
if current_step <= 0:
return 1 / (self.num_warmup_steps + 1)
frac = 1 - current_step / self.num_warmup_steps
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
def constant_schedule(current_step):
return 1
if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step)
return constant_schedule(current_step)
return LambdaLR(optimizer, lr_lambda, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
state_dict = scheduler.state_dict()
write_json(state_dict, save_dir / SCHEDULER_STATE)

View File

@ -13,6 +13,7 @@
# limitations under the License.
from .act.configuration_act import ACTConfig as ACTConfig
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig

View File

@ -0,0 +1,140 @@
<h1 align="center">
DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Policy Learning</h1>
This policy is community contributed. For more information about DexVLA, you can also refer to the [original repository](https://github.com/juruobenruo/DexVLA) and the [project website](https://dex-vla.github.io/).
## Dataset
### Data format
DexVLA takes RGB images, language instructions, and robot states as input. In our setting, we use three camera views: a top camera and two wrist cameras.
⭐A major difference between DexVLA and other VLAs is that DexVLA takes in raw language and outputs sub-step reasoning based on the current observations.
So you have to <font color='red'>add sub-step reasoning to your data for training</font>.
Specifically, your data should include a key ``reasoning``, which is a list of sub-step reasoning strings, one per observation.
For example, if the episode has 10 steps, the length of this list should be 10 as well, and it may look like:
~~~python
reasoning = [
"This is step 1.",
"This is step 1.",
"This is step 2.",
"This is step 2.",
...
"This is step 4.",
]
~~~
Besides, your data should include another key ``action_is_pad``, which is a boolean mask indicating which entries of each action chunk are padding.
Suppose the action chunk size is 5 and the episode length is 10: the chunks that extend past the end of the episode must be padded so that every chunk still has length 5.
The masks for the last chunks then look like this (a small helper that generates such masks is sketched after the example):
~~~python
# action_is_pad masks for the last chunks of a 10-step episode with chunk size 5
chunk_6 = [False, False, False, False, True]
chunk_7 = [False, False, False, True, True]
chunk_8 = [False, False, True, True, True]
chunk_9 = [False, True, True, True, True]
~~~
### Training Data for DexVLA
The pretraining dataset comprises approximately 100 hours of data collected by ourselves. It mainly covers four embodiments, including the mobile Agilex Aloha, a single Franka Emika arm, and a single UR5e arm.
We did not use any public dataset such as Open-X or DROID.
## 🤗Download Pretrained Weights
### Download official Qwen2_VL weights
We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
Qwen2-VL-2B serves as the core of our architecture, providing robust capabilities for vision-language tasks.
We use the off-the-shelf Qwen2-VL model proposed in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post-training of the VLM itself. You can download the official weights from this link:
| Model | Link |
|---------------------|----------------------------------------------------------------|
| Qwen2-VL (~2B) | [huggingface](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) |
**❗❗** After downloading the standard weights, you have to replace the official "config.json"
with our ["config.json"](https://github.com/juruobenruo/DexVLA/blob/main/docs/config.json) designed for VLA.
### Download our pretrained ScaleDP-H weights (Stage 1)
We release the pretrained weights of ScaleDP-H, trained after Stage 1. You can download the weights and directly fine-tune on your own data in Stage 2.
| Model | Link |
|-------------------|----------------------------------------------------------------|
| ScaleDP-H (~1B) | [huggingface](https://huggingface.co/lesjie/scale_dp_h) |
| ScaleDP-L (~400M) | [huggingface](https://huggingface.co/lesjie/scale_dp_l) |
**❗❗** After downloading the weights, you have to convert them into ``safetensors`` format. You can simply run this code:
~~~python
import torch
from safetensors.torch import save_file
path = "/path/to/open_scale_dp_l_backbone.ckpt"
checkpoint = torch.load(path, map_location=torch.device('cpu'))['nets']['nets']
# Save the weights in safetensors format
safetensors_path = "/path/to/open_scale_dp_l_backbone.safetensors"
save_file(checkpoint, safetensors_path)
print(f"Converted {path} to {safetensors_path}")
~~~
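Optionally, as a quick sanity check, you can reload the converted file and inspect a few tensor names (the path below is a placeholder):
~~~python
from safetensors.torch import load_file

state_dict = load_file("/path/to/open_scale_dp_l_backbone.safetensors")
print(len(state_dict), "tensors")
print(list(state_dict.keys())[:5])
~~~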
## 🦾Train
We have already provided the pretrained ScaleDP weights, i.e. the output of Stage 1. The following sections cover the training process for Stage 2 and Stage 3.
### Training Stage 2
~~~shell
python lerobot/scripts/train.py \
--policy.type dexvla \
--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
--policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \
--policy.policy_head_size 'scaledp_h' \
--policy.training_stage 2 \
--dataset.repo_id lerobot/aloha_mobile_chair \
--policy.using_film true \
--output_dir /path/to/output \
--steps 10000 \
--save_freq 1000 \
--optimizer_lr 2e-5
~~~
### Training Stage 3
Stage 3 can be viewed as continual training on specific dexterous tasks, such as laundry folding, in the same spirit as PI0. Stage 3 is therefore trained on top of the Stage 2 weights.
~~~shell
python lerobot/scripts/train.py \
--policy.type dexvla \
--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
--policy.pretrained_path /path/to/pretrained/stage2/weights \
--policy.policy_head_size 'scaledp_h' \
--policy.training_stage 3 \
--dataset.repo_id lerobot/aloha_mobile_chair \
--batch_size 2 \
--policy.using_film true \
--output_dir /path/to/output \
--steps 10000 \
--save_freq 1000 \
--optimizer_lr 2e-5
~~~
### Training Time
The original DexVLA is trained on 8 x H100 GPUs. The training time for each stage is listed below:
| Stage | Batch size (per GPU) | Steps | Time (hours) |
|--------|----------------------|--------|------------|
| Stage1 | 32 | 60000 | 30 |
| Stage2 | 12 | 100000 | 30 |
| Stage3 | 12 | 60000 | 18 |
## Evaluation
### Evaluation Script
You can evaluate DexVLA with the following script.
~~~shell
python lerobot/scripts/eval.py \
--policy.type dexvla \
--policy.pretrained_path /path/to/pretrained/stage2/or/stage3/weights \
--env.type aloha \
--env.episode_length 5 \
--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
--env.task AlohaInsertion-v0 \
--eval.n_episodes 1 \
--eval.batch_size 1
~~~
### Inference Speed
Tested on a single A6000 GPU, DexVLA can infer 3.4 action chunks per second. If we execute 25 actions from each chunk, the effective control frequency is about 85 Hz (3.4 × 25).

View File

@ -0,0 +1,179 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""
from dataclasses import dataclass, field
from typing import Tuple
from transformers import AutoConfig
from transformers.utils import logging
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
ConstantWithWarmupSchedulerConfig,
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla
logger = logging.get_logger(__name__)
register_policy_heads()
register_qwen2_vla()
@PreTrainedConfig.register_subclass("dexvla")
@dataclass
class DexVLAConfig(PreTrainedConfig):
# For loading policy head
policy_head_type: str = "scale_dp_policy"
policy_head_size: str = "scaledp_l"
action_dim: int = 14
state_dim: int = 14
chunk_size: int = 50
n_action_steps: int = 50
n_obs_steps: int = 1
device: str = "cuda"
hidden_size: int = 1536
qwen2_vl_path: str = (
None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
)
pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
training_stage: int = 2 # specific training stage, [2, 3]
using_film: bool = True
llm_loss_weight: float = 1.0
with_llm_head: bool = True
using_reasoning: bool = True
resize_size: tuple = (240, 320)
# Training presets
optimizer_lr: float = 2e-5
optimizer_betas: Tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
scheduler_warmup_steps: int = 2_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
# "VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
def __post_init__(self):
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.using_reasoning:
assert self.using_film, "using_reasoning requires `using_film=True`"
assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`"
print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.")
else:
print(
"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'."
)
if self.qwen2_vl_path is None:
raise ValueError(
"DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'."
)
if self.policy_head_type == "scale_dp_policy":
self.policy_head_config = AutoConfig.for_model(
model_type=self.policy_head_type,
model_size=self.policy_head_size,
cond_dim=self.hidden_size,
action_dim=self.action_dim,
prediction_horizon=self.chunk_size,
state_dim=self.state_dim,
)
elif self.policy_head_type == "unet_diffusion":
self.policy_head_config = AutoConfig.for_model(
model_type=self.policy_head_type,
global_cond_dim=self.hidden_size,
action_dim=self.action_dim,
state_dim=self.state_dim,
)
else:
raise ValueError(f"Policy head type {self.policy_head_type} not supported")
if self.training_stage not in [2, 3]:
raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
def validate_features(self) -> None:
# TODO: implement value error
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
# for i in range(self.empty_cameras):
# key = f"observation.images.empty_camera_{i}"
# empty_camera = PolicyFeature(
# type=FeatureType.VISUAL,
# shape=(3, 480, 640),
# )
# self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self):
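# Stage 3 fine-tuning decays the LR with a cosine schedule after warmup; Stage 2 keeps a constant LR after warmup.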
if self.training_stage == 3:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
else:
return ConstantWithWarmupSchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@ -0,0 +1,58 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
class ActionProjector(nn.Module):
def __init__(self, in_dim, out_dim=1024):
super().__init__()
self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
self.mlps = nn.ModuleList(
[
# nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.Dropout(0.0),
]
)
def forward(self, x):
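# Average-pool across the token dimension: (tokens, dim) -> (1, dim), then project through the Linear/GELU/Linear/Dropout stack.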
x = self.global_1d_pool(x.permute(1, 0)).permute(1, 0)
for mlp in self.mlps:
x = mlp(x)
return x
class FiLM(nn.Module):
def __init__(self, feature_dim, condition_dim):
super().__init__()
self.scale_fc = nn.Linear(condition_dim, feature_dim)
self.shift_fc = nn.Linear(condition_dim, feature_dim)
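# Zero-initialize the scale/shift projections so FiLM starts as an identity mapping; the conditioning is learned during training.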
nn.init.zeros_(self.scale_fc.weight)
nn.init.zeros_(self.scale_fc.bias)
nn.init.zeros_(self.shift_fc.weight)
nn.init.zeros_(self.shift_fc.bias)
def forward(self, x, condition):
# calculate scale and shift
scale = self.scale_fc(condition)
shift = self.shift_fc(condition)
# film
return x * (1 + scale) + shift

View File

@ -0,0 +1,313 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import torch
import torchvision.transforms as transforms
from safetensors.torch import load_file
from torch import Tensor
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
class DexVLAPolicy(PreTrainedPolicy):
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
config_class = DexVLAConfig
name = "dexvla"
def __init__(
self,
config: DexVLAConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]:
setattr(config.qwen2_vla_config, k, config.__dict__[k])
# if self.config.training_stage == 2:
# self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16)
model_base = self.config.qwen2_vl_path
self.model = AutoModelForCausalLM.from_pretrained(
model_base,
config=config.qwen2_vla_config,
trust_remote_code=True,
_fast_init=False,
# attn_implementation="flash_attention_2",
).to(device="cuda", dtype=torch.bfloat16)
if self.config.pretrained_scaledp_path is not None:
print(
"\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
)
pretrain_scaledp_weights = load_file(self.config.pretrained_scaledp_path)
keys_to_del_dit = []
pretrain_scaledp_weights = {
k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items()
}
for k in pretrain_scaledp_weights:
if "noise_pred" not in k: # del weights of vision backbones
keys_to_del_dit.append(k)
if "cond_obs_emb" in k:
keys_to_del_dit.append(k)
for k in keys_to_del_dit:
del pretrain_scaledp_weights[k]
pretrain_scaledp_weights = {
k[15:] if k.startswith("noise_pred_net.") else k: v
for k, v in pretrain_scaledp_weights.items()
}
self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False)
self.model.requires_grad_(False)
self.model.policy_head.requires_grad_(True)
self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path)
self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path)
self.vla_processor = Qwen2VLAProcess(
tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor
) # process the input data into VLM format
self.resize_size = self.config.resize_size
ratio = 0.95
self.transformations = [
transforms.Resize(size=self.resize_size, antialias=True),
transforms.RandomCrop(size=[int(self.resize_size[0] * ratio), int(self.resize_size[1] * ratio)]),
transforms.Resize(self.resize_size, antialias=True),
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), # , hue=0.08)
]
self.reset()
def process_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Applying DexVLA preprocessing to original data. Including resizing images. Scaling the range of actions, states."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
present_img_keys = [key for key in self.config.image_features if key in batch]
task_descs = batch["task"]
try:
reasonings = batch["reasoning"]
except KeyError:
reasonings = ["None."] * len(task_descs)
is_pad = batch["action_is_pad"]
all_cam_images = []
for k in present_img_keys:
all_cam_images.append(batch[k])
# construct observations, and scale 0-1 to 0-255
image_data = torch.stack(all_cam_images) * 255
image_data = image_data.to(dtype=torch.uint8)
# construct observations
qpos_data = batch["observation.state"].float()
action_data = batch["action"].float()
orig_shape = image_data.shape
image_data = image_data.view(-1, *orig_shape[2:])
for transform in self.transformations:
image_data = transform(image_data)
image_data = image_data.view(*orig_shape[:3], *self.resize_size)
vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings}
# processing vl_data into qwen2_vl format
vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning)
vla_inputs["states"] = qpos_data
vla_inputs["is_pad"] = is_pad
vla_inputs["actions"] = action_data
return vla_inputs
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
processed_batch = self.process_batch(batch)
ret = self.model.forward(**processed_batch)
loss_dict = ret["loss"]
loss = loss_dict["loss"].mean()
return loss, loss_dict
def dexvla_predict_action(
self,
input_ids: torch.LongTensor = None,
actions=None,
states=None,
is_pad=None,
tokenizer=None,
is_eval=True,
pixel_values=None,
attention_mask=None,
image_grid_spatiotemporal=None,
):
input_ids = input_ids.to("cuda")
with torch.inference_mode():
outputs = self.model.generate(
input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
image_grid_spatiotemporal=image_grid_spatiotemporal,
is_eval=is_eval,
num_beams=1,
do_sample=False,
temperature=0.2,
max_new_tokens=256,
eos_token_id=tokenizer.eos_token_id, # End of sequence token
pad_token_id=tokenizer.eos_token_id, # Pad token
use_cache=True,
output_hidden_states=True,
return_dict_in_generate=True,
)
output_ids = outputs.sequences
# last_hidden_states = outputs.hidden_states[-2][-1]
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0]
outputs_text = outputs_text.strip()
last_hidden_states = [each[-1] for each in outputs.hidden_states] # all hidden states
all_hidden_states = torch.cat(last_hidden_states, dim=1)
action_hidden_states = None
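# Label prompt tokens with -100 (the usual ignore index) and generated tokens with 1; these labels are passed to film_forward below.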
labels_input = torch.ones((1, input_token_len)) * -100
labels_output = torch.ones((1, output_ids.shape[1] - input_token_len))
labels = torch.cat([labels_input, labels_output], dim=1)
if self.model.using_film:
action_hidden_states = self.model.film_forward(
labels=labels,
input_ids=output_ids,
hidden_states=torch.cat(last_hidden_states, dim=1),
)
action = self.model.policy_head(
actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad
)
return action, outputs_text
def tinyvla_predict_action(
self,
input_ids: torch.LongTensor = None,
actions=None,
states=None,
is_pad=None,
is_eval=True,
pixel_values=None,
attention_mask=None,
image_grid_spatiotemporal=None,
):
input_ids = input_ids.to("cuda")
with torch.inference_mode():
all_hidden_states = self.model.forward(
input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
image_grid_spatiotemporal=image_grid_spatiotemporal,
is_eval=is_eval,
tinyvla=True,
)
all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1)
action = self.model.policy_head(
actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad
)
return action, "tinyvla generates no reasoning"
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method returns one action at a time for execution in the environment. It works by managing
predicted action chunks in a queue and only running the model when the queue is empty.
"""
self.eval()
batch = self.normalize_inputs(batch)
if len(self._action_queue) == 0:
present_img_keys = [key for key in self.config.image_features if key in batch]
try:
task_descs = batch["task"]
except KeyError:
task_descs = " "
print("No task descriptions found for this task")
all_cam_images = []
for k in present_img_keys:
all_cam_images.append(batch[k])
# construct observations, and scale 0-1 to 0-255
image_data = torch.stack(all_cam_images) * 255
image_data = image_data.to(dtype=torch.uint8)
# construct observations
qpos_data = batch["observation.state"].float()
image_data = image_data.squeeze(0)
for transform in self.transformations:
image_data = transform(image_data)
# processing vl_data into qwen2_vl format
vla_inputs = self.vla_processor.single_forward_process(
images=image_data, raw_lang=task_descs, reasoning=None, eval=True
)
vla_inputs["states"] = qpos_data
if self.config.using_film and self.config.with_llm_head: # dexvla
all_actions, outputs = self.dexvla_predict_action(
**vla_inputs, is_eval=True, tokenizer=self.tokenizer
)
else: # tinyvla
all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True)
actions = self.unnormalize_outputs({"action": all_actions})["action"]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoConfig, AutoModel
from .configuration_scaledp import ScaleDPPolicyConfig
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
from .modeling_scaledp import ScaleDP
from .modeling_unet_diffusion import ConditionalUnet1D
def register_policy_heads():
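# Register the custom policy-head config/model classes with transformers' Auto* registries so they can be instantiated by model_type (e.g. AutoConfig.for_model("scale_dp_policy", ...)).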
AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union
from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
MODEL_STRUCTURE = {
"scaledp_h": {
"depth": 32,
"n_emb": 1280,
"num_heads": 16,
},
"scaledp_l": {
"depth": 24,
"n_emb": 1024,
"num_heads": 16,
}, # 400M
}
class ScaleDPPolicyConfig(PretrainedConfig):
"""
Configuration for ScaleDP policy head
"""
model_type = "scale_dp_policy"
def __init__(
self,
eval: bool = False,
action_dim: int = 14, # action dim
# output_dim: int = 14, # action dim
cond_dim: int = 1536, # the input dim of the condition
state_dim: int = 14, # the input dim of the state
prediction_horizon: int = 16, # horizon
n_obs_steps: int = 2, # number of observation steps
depth: int = 28, # number of DiT blocks
n_emb: int = 256, # embedding size
num_heads: int = 16,
mlp_ratio: float = 4.0,
time_as_cond: bool = True,
obs_as_cond: bool = True,
learn_sigma: bool = False,
model_size: str = "none",
num_inference_timesteps: int = 10,
noise_samples: int = 1,
num_train_timesteps: int = 100,
**kwargs,
):
if model_size != "none":
depth = MODEL_STRUCTURE[model_size]["depth"]
n_emb = MODEL_STRUCTURE[model_size]["n_emb"]
num_heads = MODEL_STRUCTURE[model_size]["num_heads"]
else:
# model_size == "none": fall back to the explicitly provided depth / n_emb / num_heads
pass
self.eval = eval
self.input_dim = action_dim
self.output_dim = action_dim
self.prediction_horizon = prediction_horizon
self.cond_dim = cond_dim
self.state_dim = state_dim
self.n_obs_steps = n_obs_steps
self.depth = depth
self.n_emb = n_emb
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.time_as_cond = time_as_cond
self.obs_as_cond = obs_as_cond
self.learn_sigma = learn_sigma
self.num_inference_timesteps = num_inference_timesteps
self.num_queries = prediction_horizon
self.noise_samples = noise_samples
self.num_train_timesteps = num_train_timesteps
super().__init__(**kwargs)
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# if loading from a full llava_pythia config, extract the action_head sub-config
if config_dict.get("model_type") == "llava_pythia":
config_dict = config_dict["action_head"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union
from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class UnetDiffusionPolicyConfig(PretrainedConfig):
"""
Configuration for the UNet diffusion policy head
"""
model_type = "unet_diffusion_policy"
def __init__(
self,
action_dim=10,
global_cond_dim=2048,
diffusion_step_embed_dim=256,
down_dims=None,
kernel_size=5,
n_groups=8,
state_dim=7,
prediction_horizon=16,
noise_samples=1,
num_inference_timesteps=10,
num_train_timesteps=100,
**kwargs,
):
if down_dims is None:
down_dims = [256, 512, 1024]
self.input_dim = action_dim
self.noise_samples = noise_samples
self.prediction_horizon = prediction_horizon
self.num_inference_timesteps = num_inference_timesteps
self.global_cond_dim = global_cond_dim
self.diffusion_step_embed_dim = diffusion_step_embed_dim
self.down_dims = down_dims
self.kernel_size = kernel_size
self.n_groups = n_groups
self.state_dim = state_dim
self.num_train_timesteps = num_train_timesteps
super().__init__(**kwargs)
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# if loading from a full llava_pythia config, extract the action_head sub-config
if config_dict.get("model_type") == "llava_pythia":
config_dict = config_dict["action_head"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)

View File

@ -0,0 +1,561 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.utils.checkpoint
from timm.models.vision_transformer import Mlp, use_fused_attn
from torch.jit import Final
from transformers.modeling_utils import PreTrainedModel
from .configuration_scaledp import ScaleDPPolicyConfig
_logger = logging.getLogger(__name__)
class Attention(nn.Module):
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
b, n, c = x.shape
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = func.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
else:
q = q * self.scale
# attn = q @ k.transpose(-2, -1)
# if attn_mask is not None:
# attn += attn_mask
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)
# x = attn @ v
attn_scores = torch.matmul(q, k.transpose(-2, -1))
# Add attention mask if provided
if attn_mask is not None:
attn_scores += attn_mask
# Apply softmax to get attention weights (softmax is applied along the last dimension)
attn_weights = func.softmax(attn_scores, dim=-1)
# Dropout on attention weights (if dropout is used)
attn_weights = self.attn_drop(attn_weights)
# Apply attention weights to value tensor (V)
x = torch.matmul(attn_weights, v)
x = x.transpose(1, 2).reshape(b, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
logger = logging.getLogger(__name__)
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.bfloat16) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding.to(dtype=torch.bfloat16)
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
#################################################################################
# Core ScaleDP Model #
#################################################################################
class ScaleDPBlock(nn.Module):
"""
A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
def approx_gelu():
return nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
def forward(self, x, c, attn_mask=None):
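# adaLN-Zero: the conditioning vector c is projected to six terms (shift/scale/gate for the attention branch and for the MLP branch).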
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
6, dim=1
)
x = x + gate_msa.unsqueeze(1) * self.attn(
modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask
) # norm, scale&shift, attn, scale,
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FinalLayer(nn.Module):
"""
The final layer of ScaleDP.
"""
def __init__(self, hidden_size, output_dim):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, output_dim, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class ScaleDP(PreTrainedModel):
"""
Diffusion models with a Transformer backbone.
"""
config_class = ScaleDPPolicyConfig
def __init__(
self,
config: ScaleDPPolicyConfig,
):
super().__init__(config)
# compute number of tokens for main trunk and condition encoder
if config.n_obs_steps is None:
config.n_obs_steps = config.prediction_horizon
t = config.prediction_horizon
t_cond = 1
if not config.time_as_cond:
t += 1
t_cond -= 1
obs_as_cond = config.cond_dim > 0
if obs_as_cond:
assert config.time_as_cond
t_cond += config.n_obs_steps
# self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
self.combine = nn.Sequential(
nn.Linear(config.cond_dim + config.state_dim, 1024),
nn.ReLU(),
nn.Linear(1024, 1024),
nn.ReLU(),
nn.Linear(1024, config.cond_dim),
)
self.learn_sigma = config.learn_sigma
self.input_dim = config.input_dim
self.output_dim = config.output_dim * 2 if config.learn_sigma else config.output_dim
self.num_heads = config.num_heads
self.x_embedder = nn.Linear(config.input_dim, config.n_emb)
self.t_embedder = TimestepEmbedder(config.n_emb)
self.cond_obs_emb = None
if obs_as_cond:
self.cond_obs_emb = nn.Linear(config.cond_dim, config.n_emb)
# Will use fixed sin-cos embedding:
self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb))
self.blocks = nn.ModuleList(
[
ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio)
for _ in range(config.depth)
]
)
self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
# self.initialize_weights()
# constants
self.t = t
self.t_cond = t_cond
self.prediction_horizon = config.prediction_horizon
self.time_as_cond = config.time_as_cond
self.action_dim = config.output_dim
self.obs_as_cond = obs_as_cond
logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()))
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
self.num_inference_timesteps = config.num_inference_timesteps
# self.proj_to_action = nn.Identity()
self.noise_scheduler = DDIMScheduler(
num_train_timesteps=config.num_train_timesteps, # 100
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
set_alpha_to_one=True,
steps_offset=0,
prediction_type="epsilon",
)
self.num_queries = config.num_queries # 16
self.noise_samples = config.noise_samples # 1
# self.num_inference_timesteps = config.num_inference_timesteps # 100
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
nn.init.normal_(self.pos_embed, mean=0.0, std=0.02)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.cond_obs_emb.weight, mean=0.0, std=0.02)
nn.init.constant_(self.cond_obs_emb.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in ScaleDP blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def get_optim_groups(self, weight_decay: float = 1e-3):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the models into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, Attention)
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, _p in m.named_parameters():
fpn = "{}.{}".format(mn, pn) if mn else pn # full param name
if pn.endswith("bias"):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.startswith("bias"):
# MultiheadAttention bias starts with "bias"
no_decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# validate that we considered every parameter
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
)
)
# create the pytorch optimizer object
optim_groups = [
{
"params": [param_dict[pn] for pn in sorted(decay)],
"weight_decay": weight_decay,
},
{
"params": [param_dict[pn] for pn in sorted(no_decay)],
"weight_decay": 0.0,
},
]
return optim_groups
def configure_optimizers(
self,
learning_rate: float = 1e-4,
weight_decay: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.95),
):
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
return optimizer
def forward(self, actions, hidden_states, states, is_pad):
"""
Forward pass for the diffusion head.
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
:param hidden_states: hidden states from the VLM, used as the condition for the diffusion, shape [b, Tokens, D] (e.g. 8 x 1200 x 1024)
:param states: robot states, shape [b, D]
:return: loss
"""
if actions is not None: # training time
b = actions.size(0)
actions = actions[:, : self.num_queries]
is_pad = is_pad[:, : self.num_queries]
num_noise_samples = self.noise_samples
# sample noise to add to actions
noise = torch.randn(
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
) # num_noise, b, Ta, D(1, 2, 16, 14)
# sample a diffusion iteration for each data point
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
).long()
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
# (this is the forward diffusion process)
noisy_actions = torch.cat(
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
dim=0,
) # [num_noise_samples * b, Ta, action_dim]
noisy_actions = noisy_actions.to(dtype=actions.dtype)
assert hidden_states.ndim == 3
hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
timesteps = timesteps.repeat(num_noise_samples)
is_pad = is_pad.repeat(num_noise_samples, 1)
states = states.repeat(num_noise_samples, 1)
noise_pred = self.model_forward(
noisy_actions, timesteps, global_cond=hidden_states, states=states
)
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
# loss_dict['loss'] = loss
return {"loss": loss}
# return loss
else: # inference time
b = 1
tp = self.num_queries
action_dim = self.action_dim
# initialize action from Gaussian noise
noisy_action = torch.randn((b, tp, action_dim)).cuda()
naction = noisy_action.to(dtype=hidden_states.dtype)
# init scheduler
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
for k in self.noise_scheduler.timesteps:
# predict noise
noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)
# inverse diffusion step (remove noise)
naction = self.noise_scheduler.step(
model_output=noise_pred, timestep=k, sample=naction
).prev_sample
return naction
def model_forward(self, x, t, global_cond, states):
"""
Forward pass of ScaleDP.
x: (N, T, input_dim) noisy actions
t: (N,) tensor of diffusion timesteps
global_cond: (N, n_obs_steps, D) tensor of conditions: image embeddings
"""
global_cond = global_cond.squeeze(1)
global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
global_cond = self.combine(global_cond)
if not torch.is_tensor(t):
t = torch.tensor([t], dtype=torch.long, device=x.device)
elif torch.is_tensor(t) and len(t.shape) == 0:
t = t[None].to(x.device)
t = t.expand(t.shape[0])
x = self.x_embedder(x) + self.pos_embed.to(
device=x.device, dtype=x.dtype
) # (N, T, D), where T = prediction_horizon
t = self.t_embedder(t) # (N, D)
if self.obs_as_cond:
global_cond = self.cond_obs_emb(global_cond) # (N, D)
# c = t + global_cond.sum(dim=1) # (N, D)
c = t + global_cond # (N, D)
for block in self.blocks:
# x = block(x, c, attn_mask=self.mask) # (N, T, D)
x = block(x, c, attn_mask=None) # (N, T, D)
x = self.final_layer(x, c) # (N, T, output_dim)
return x
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
#################################################################################
# ScaleDP Configs #
#################################################################################
def scaledp_h(**kwargs):
return ScaleDP(ScaleDPPolicyConfig(depth=32, n_emb=1280, num_heads=16, **kwargs))
def scaledp_l(**kwargs):
return ScaleDP(ScaleDPPolicyConfig(depth=24, n_emb=1024, num_heads=16, **kwargs))

View File

@ -0,0 +1,387 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import Union
import torch
import torch.nn as nn
# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from transformers.modeling_utils import PreTrainedModel
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
# =================== UNet for Diffusion ==============
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim, dtype):
super().__init__()
self.dim = dim
self.dtype = dtype
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class ConditionalResidualBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
]
)
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels * 2
self.out_channels = out_channels
self.cond_encoder = nn.Sequential(
nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
)
# make sure dimensions compatible
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x, cond):
"""
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x)
embed = self.cond_encoder(cond)
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
out = scale * out + bias
out = self.blocks[1](out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(PreTrainedModel):
_no_split_modules = ["mid_modules", "down_modules", "up_modules"]
config_class = UnetDiffusionPolicyConfig
def __init__(self, config: UnetDiffusionPolicyConfig):
"""
input_dim: Dim of actions.
global_cond_dim: Dim of global conditioning applied with FiLM
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
down_dims: Channel size for each UNet level.
The length of this array determines number of levels.
kernel_size: Conv kernel size
n_groups: Number of groups for GroupNorm
"""
super().__init__(config)
all_dims = [config.input_dim] + list(config.down_dims)
start_dim = config.down_dims[0]
self.num_queries = config.prediction_horizon
self.noise_samples = config.noise_samples
# self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
# self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim)
self.norm_after_pool = nn.LayerNorm(config.global_cond_dim)
self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim)
dsed = config.diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed, torch.bfloat16),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed + config.global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList(
[
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
]
)
down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_out * 2,
dim_in,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
ConditionalResidualBlock1D(
dim_in,
dim_in,
cond_dim=cond_dim,
kernel_size=config.kernel_size,
n_groups=config.n_groups,
),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
nn.Conv1d(start_dim, config.input_dim, 1),
)
self.diffusion_step_encoder = diffusion_step_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters())))
self.num_inference_timesteps = config.num_inference_timesteps
# self.proj_to_action = nn.Identity()
self.noise_scheduler = DDIMScheduler(
num_train_timesteps=config.num_train_timesteps, # 100
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
set_alpha_to_one=True,
steps_offset=0,
prediction_type="epsilon",
)
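# With prediction_type="epsilon", the network is trained to predict the injected noise,
# which is why the training loss below compares noise_pred against noise. At inference,
# set_timesteps(num_inference_timesteps) subsamples the num_train_timesteps training
# schedule into a shorter DDIM schedule.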
# self.num_inference_timesteps = config.num_inference_timesteps # 100
def forward(self, actions, hidden_states, states, is_pad):
"""
Forward pass for the diffusion head.
:param actions: target actions, shape [b, Ta, D] (e.g. D = 10 = 3 + 6 + 1)
:param hidden_states: hidden states from the VLM backbone (Qwen2-VL in DexVLA), used as the condition for the diffusion head, shape [b, tokens, D]
:param states: robot proprioceptive states, shape [b, D_state]
:param is_pad: bool mask marking padded action steps, shape [b, Ta]
:return: a dict with the training loss, or the denoised action chunk at inference time
"""
if actions is not None: # training time
b = actions.size(0)
actions = copy.deepcopy(actions[:, : self.num_queries])
is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
num_noise_samples = self.noise_samples
# sample noise to add to actions
noise = torch.randn(
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
) # num_noise, b, Ta, D
# sample a diffusion iteration for each data point
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
).long()
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
# (this is the forward diffusion process)
noisy_actions = torch.cat(
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
dim=0,
) # [num_noise_samples * b, Ta, action_dim]
noisy_actions = noisy_actions.to(dtype=actions.dtype)
assert hidden_states.ndim == 3
hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
timesteps = timesteps.repeat(num_noise_samples)
is_pad = is_pad.repeat(num_noise_samples, 1)
states = states.repeat(num_noise_samples, 1)
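# Each clean action chunk is diffused num_noise_samples times with independent noise,
# so the conditioning tensors (hidden_states, timesteps, is_pad, states) are tiled along
# the batch dimension to match the [num_noise_samples * b, ...] noisy actions.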
noise_pred = self.model_forward(
noisy_actions, timesteps, global_cond=hidden_states, states=states
)
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
# loss_dict['loss'] = loss
return {"loss": loss}
# return loss
else: # inference time
b = 1
tp = self.num_queries
action_dim = 14
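# NOTE: the action dimension above and the CUDA device below are hard-coded here and
# assumed to match the training setup; they are not read from the config.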
# initialize action from Gaussian noise
noisy_action = torch.randn((b, tp, action_dim)).cuda()
naction = noisy_action.to(dtype=hidden_states.dtype)
# init scheduler
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
for k in self.noise_scheduler.timesteps:
# predict noise
noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)
# inverse diffusion step (remove noise)
naction = self.noise_scheduler.step(
model_output=noise_pred, timestep=k, sample=naction
).prev_sample
return naction
def model_forward(
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
):
"""
sample: (b, T, input_dim) noisy action sequence
timestep: (b,) or int, diffusion step
global_cond: (b, 1, global_cond_dim) or (b, global_cond_dim) conditioning features
states: (b, state_dim) robot states, concatenated to the conditioning
output: (b, T, input_dim) predicted noise
"""
# (b,t,c)
sample = sample.moveaxis(-1, -2)
# (b,c,t)
# global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
global_cond = global_cond.squeeze(1)
global_cond = self.norm_after_pool(global_cond)
global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
global_cond = self.combine(global_cond)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None:
global_feature = torch.cat([global_feature, global_cond], axis=-1)
x = sample
h = []
for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
h.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
# (b,c,t)
x = x.moveaxis(-1, -2)
# (b,t,c)
return x
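# Usage notes (inferred from this file, stated here as a non-authoritative summary):
# - training: forward(actions[b, Ta, D], hidden_states[b, tokens, global_cond_dim],
#   states[b, state_dim], is_pad[b, Ta]) returns {"loss": scalar}.
# - inference: forward(None, hidden_states, states, None) runs the reverse DDIM loop and
#   returns a denoised action chunk of shape [1, prediction_horizon, 14] on the CUDA device.
# - the config fields consumed here are input_dim, prediction_horizon, noise_samples,
#   global_cond_dim, state_dim, diffusion_step_embed_dim, down_dims, kernel_size,
#   n_groups, num_train_timesteps and num_inference_timesteps.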

View File

@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoConfig, AutoModelForCausalLM
from .configuration_qwen2_vla import Qwen2VLAConfig
from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA
def register_qwen2_vla():
AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
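# After register_qwen2_vla() is called, checkpoints whose config.json declares
# `"model_type": "qwen2_vla"` can be loaded through the transformers Auto classes
# (AutoConfig / AutoModelForCausalLM), which will dispatch to the classes above.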

View File

@ -0,0 +1,254 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen2VLVisionConfig(PretrainedConfig):
model_type = "qwen2_vl"
def __init__(
self,
depth=32,
embed_dim=1280,
hidden_size=3584,
hidden_act="quick_gelu",
mlp_ratio=4,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.embed_dim = embed_dim
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if config_dict.get("model_type") == "qwen2_vl":
config_dict = config_dict["vision_config"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class Qwen2VLAConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2VLForConditionalGenerationForVLA`],
the Qwen2-VL backbone used by DexVLA. It is used to instantiate the model according to the specified
arguments, defining the model architecture. Note that the defaults below describe a larger Qwen2-VL
variant than [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct); pass the
7B values explicitly if that is the intended backbone.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 152064):
Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2VLModel`]
hidden_size (`int`, *optional*, defaults to 8192):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 29568):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 80):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 80):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
vision_config (`Dict`, *optional*):
The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python
>>> # Initializing a Qwen2VLA style configuration (classes provided by this package)
>>> configuration = Qwen2VLAConfig()
>>> # Initializing the DexVLA backbone model from the configuration
>>> model = Qwen2VLForConditionalGenerationForVLA(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_vla"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
# For loading policy head
policy_head_type="scale_dp_policy", # unet_diffusion_policy
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = Qwen2VLVisionConfig(**vision_config)
elif vision_config is None:
self.vision_config = Qwen2VLVisionConfig()
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
self.policy_head_type = policy_head_type # for loading policy head
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self, ignore_keys={"mrope_section"})
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
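# Illustrative only (argument values are assumptions): the config can be built directly, e.g.
#   Qwen2VLAConfig(vision_config={"depth": 32, "embed_dim": 1280}, policy_head_type="unet_diffusion_policy")
# Vision keys omitted from the dict fall back to the Qwen2VLVisionConfig defaults above, and
# policy_head_type selects which action head DexVLA attaches ("scale_dp_policy" by default).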

File diff suppressed because it is too large

View File

@ -0,0 +1,172 @@
#!/usr/bin/env python
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from PIL import Image
from qwen_vl_utils import fetch_image
class Qwen2VLAProcess:
def __init__(
self,
tokenizer=None,
max_seq_len=512,
multimodal_processor=None,
):
super().__init__()
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.multimodal_processor = multimodal_processor
def qwen2_image_preprocess(self, each):
ele = {}
each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
ele["image"] = each
ele["resized_height"] = each.height
ele["resized_width"] = each.width
each = fetch_image(ele)
return torch.from_numpy(np.array(each))
def single_forward_process(self, images, raw_lang, reasoning, eval=False, use_reasoning=True):
len_views = images.shape[0]
messages = self.construct_chat_data(len_views, raw_lang)
data_dict = {"messages": messages}
image_data = torch.chunk(images, len_views, 0)
images_list = []
for _i, each in enumerate(image_data):
img_pil = self.qwen2_image_preprocess(each)
images_list.append(img_pil)
text = self.multimodal_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
model_inputs = self.multimodal_processor(
text=text,
images=images_list,
videos=None,
padding=True,
return_tensors="pt",
)
if eval:
new_dict = {}
for k, v in model_inputs.items():
if "image_grid" in k:
new_dict["image_grid_spatiotemporal"] = v
else:
new_dict[k] = v
return new_dict
input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
answer = (reasoning + " Next action:" + "<|im_end|>") if use_reasoning else "<|im_end|>"
output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
output_labels = output_text["input_ids"]
model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1)
model_inputs["attention_mask"] = torch.cat(
(model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1
)
labels = torch.cat((input_labels, output_labels), dim=-1)
data_dict["labels"] = labels
for k, v in model_inputs.items():
if "image_grid" in k:
data_dict["image_grid_spatiotemporal"] = v
else:
data_dict[k] = v
return data_dict
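# The returned dict carries the chat messages, the prompt token ids extended with the tokenized
# reasoning answer, the matching attention mask, labels with the prompt portion masked to -100,
# the processed pixel values, and the image grid metadata (renamed image_grid_spatiotemporal).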
def forward(self, batch, use_reasoning=True):
"""This is the main process function for processing vl data into Qwen2_vl format"""
all_images = batch["images"]
all_images = torch.einsum(
"v b c h w -> b v c h w", all_images
) # camera_views, batch_size, channel, height, width
ret_l = []
for idx, images in enumerate(all_images):
raw_lang = batch["raw_langs"][idx]
reasoning = batch["reasonings"][idx]
ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning)
ret_l.append(ret_dict)
return self.post_process(ret_l)
def post_process(self, instances):
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
image_grid_spatiotemporal = torch.stack(
[instance["image_grid_spatiotemporal"] for instance in instances]
)
pixel_values = torch.stack([instance["pixel_values"] for instance in instances])
pixel_values_videos = None
video_grid_spatiotemporal = None
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
labels = torch.flip(labels, dims=[1])
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
input_ids = torch.flip(input_ids, dims=[1])
b = input_ids.shape[0]
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(
b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
)
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"image_grid_spatiotemporal": image_grid_spatiotemporal,
"pixel_values_videos": pixel_values_videos,
"video_grid_spatiotemporal": video_grid_spatiotemporal,
"pixel_values": pixel_values,
}
return batch
def construct_chat_data(self, len_image, raw_lang):
messages = [
{
"role": "user",
"content": [],
},
]
for _i in range(len_image):
messages[0]["content"].append(
{
"type": "image",
"image": None,
}
)
messages[0]["content"].append({"type": "text", "text": ""})
messages[0]["content"][-1]["text"] = raw_lang
return messages
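# Illustrative example (not part of the API): for len_image=2 and raw_lang="Fold the shirt.",
# construct_chat_data returns
#   [{"role": "user",
#     "content": [{"type": "image", "image": None},
#                 {"type": "image", "image": None},
#                 {"type": "text", "text": "Fold the shirt."}]}]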

View File

@ -23,6 +23,7 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "dexvla":
from lerobot.common.policies.dexvla.modeling_dexvla import DexVLAPolicy
return DexVLAPolicy
elif name == "pi0fast":
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
@ -74,6 +79,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "dexvla":
return DexVLAConfig(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
else:

View File

@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"]
dexvla = ["transformers>=4.45.2", "qwen_vl_utils==0.0.10", "timm==0.9.10"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",