Merge b1c1d395c1
into 5322417c03
This commit is contained in:
commit
5d4d9df58f
|
@ -111,6 +111,32 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
|||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("constant_with_warmup")
|
||||
@dataclass
|
||||
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Used by DexVLA to train Stage2"""
|
||||
|
||||
num_warmup_steps: int
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step):
|
||||
def linear_warmup_schedule(current_step):
|
||||
if current_step <= 0:
|
||||
return 1 / (self.num_warmup_steps + 1)
|
||||
frac = 1 - current_step / self.num_warmup_steps
|
||||
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
|
||||
|
||||
def constant_schedule(current_step):
|
||||
return 1
|
||||
|
||||
if current_step < self.num_warmup_steps:
|
||||
return linear_warmup_schedule(current_step)
|
||||
|
||||
return constant_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
state_dict = scheduler.state_dict()
|
||||
write_json(state_dict, save_dir / SCHEDULER_STATE)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
<h1 align="center">
|
||||
DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Policy Learning</h1>
|
||||
|
||||
This policy is Community Contributed. For more information about DexVLA, you can also refer to [this](https://github.com/juruobenruo/DexVLA).
|
||||
This is [project website](https://dex-vla.github.io/).
|
||||
|
||||
## Dataset
|
||||
### Data format
|
||||
DexVLA takes RGB images, language instructions and states. For our setting, we use three camera views, namely a top camera and two wrist cameras.
|
||||
|
||||
⭐A major difference between DexVLA and other VLAs is: DexVLA takes in raw language, and outputs sub-step reasoning based on current observations.
|
||||
So you have to <font color='red'>add sub-step reasoning in your data for training</font>.
|
||||
|
||||
Specifically, your data should include a key ``reasoning`` which is a list of sub-step reasoning corresponding to each observation.
|
||||
For example, if the episode is 10 steps. The length of this list should be 10 as well. And it may looks like:
|
||||
~~~python
|
||||
reasoning = [
|
||||
"This is step 1.",
|
||||
"This is step 1.",
|
||||
"This is step 2.",
|
||||
"This is step 2.",
|
||||
...
|
||||
"This is step 4.",
|
||||
]
|
||||
~~~
|
||||
|
||||
Besides, your data should include another key ``action_is_pad`` which is a bool mask indicating whether this action chunk is padded.
|
||||
Suppose the size of the action chunk is 5, and the length of the episode is 10. So the action chunk for the last 4 actions must be padded to make sure the length of action chunk is 5.
|
||||
And the mask looks like:
|
||||
~~~python
|
||||
The 6th chunk: [false, false, false, false, true]
|
||||
The 7th chunk: [false, false, false, true, true]
|
||||
The 8th chunk: [false, false, true, true, true]
|
||||
The 9th chunk: [false, true, true, true, true]
|
||||
~~~
|
||||
|
||||
### Training Data for DexVLA
|
||||
The pretraining dataset comprises approximately 100 hours of collected data by ourselves. The dataset mainly including four embodiments which are: moblie Agilex Aloha, single Franka Emika and single UR5e.
|
||||
We haven't use any public dataset such as Open-X or DROID.
|
||||
|
||||
## 🤗Download Pretrained Weights
|
||||
### Download official Qwen2_VL weights
|
||||
We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
|
||||
The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities
|
||||
for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed
|
||||
in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post training on VLM itself. You can download the official weights from this link:
|
||||
|
||||
| Model | Link |
|
||||
|---------------------|----------------------------------------------------------------|
|
||||
| Qwen2-VL (~2B) | [huggingface](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) |
|
||||
|
||||
**❗❗** After downloading the standard weights, you have to replace the official "config.json"
|
||||
with our ["config.json"](https://github.com/juruobenruo/DexVLA/blob/main/docs/config.json) designed for VLA.
|
||||
### Download our pretrained ScaleDP-H weights(Stage 1)
|
||||
We released our pretrained weights of ScaleDP-H which is trained after Stage1. Now you can download the weights and directly finetuning your data on Stage 2.
|
||||
|
||||
| Model | Link |
|
||||
|-------------------|----------------------------------------------------------------|
|
||||
| ScaleDP-H (~1B) | [huggingface](https://huggingface.co/lesjie/scale_dp_h) |
|
||||
| ScaleDP-L (~400M) | [huggingface](https://huggingface.co/lesjie/scale_dp_l) |
|
||||
|
||||
**❗❗**After downloading the weights, you have to transform it into ``safetensors`` format, you can simply run this code:
|
||||
~~~python
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
path = "/path/to/open_scale_dp_l_backbone.ckpt"
|
||||
checkpoint = torch.load(path, map_location=torch.device('cpu'))['nets']['nets']
|
||||
|
||||
# Save the weights in safetensors format
|
||||
safetensors_path = "/path/to/open_scale_dp_l_backbone.safetensors"
|
||||
save_file(checkpoint, safetensors_path)
|
||||
print(f"Converted {path} to {safetensors_path}")
|
||||
pass
|
||||
|
||||
~~~
|
||||
|
||||
## 🦾Train
|
||||
We have already provided pretrained weights of ScaleDP which is stage 1. Belows are mainly about training process of Stage2 and Stage3.
|
||||
|
||||
### Training Stage 2
|
||||
~~~shell
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type dexvla \
|
||||
--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
|
||||
--policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \
|
||||
--policy.policy_head_size 'scaledp_h' \
|
||||
--policy.training_stage 2 \
|
||||
--dataset.repo_i lerobot/aloha_mobile_chair \
|
||||
--policy.using_film true \
|
||||
--output_dir /path/to/output \
|
||||
--steps 10000 \
|
||||
--save_freq 1000 \
|
||||
--optimizer_lr 2e-5
|
||||
~~~
|
||||
|
||||
### Training Stage 3
|
||||
Stage3 can be viewed as continual training on specific dexterous tasks like laundry folding which is same as PI0. So stage3 is trained based on stage2.
|
||||
~~~shell
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type dexvla \
|
||||
--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
|
||||
--.pretrained_path /path/to/pretrained/stage2/weights \
|
||||
--policy.policy_head_size 'scaledp_h' \
|
||||
--policy.training_stage 3 \
|
||||
--dataset.repo_i lerobot/aloha_mobile_chair \
|
||||
--batch_size 2 \
|
||||
--policy.using_film true \
|
||||
--output_dir /path/to/output \
|
||||
--steps 10000 \
|
||||
--save_freq 1000 \
|
||||
--optimizer_lr 2e-5
|
||||
~~~
|
||||
|
||||
### Training Time
|
||||
Original DexVLA is trained on 8 x H100 GPUs. And the training time for each stage is listed as follows:
|
||||
|
||||
| Stage | Batch Size(each gpu) | Steps | Time(hour) |
|
||||
|--------|----------------------|--------|------------|
|
||||
| Stage1 | 32 | 60000 | 30 |
|
||||
| Stage2 | 12 | 100000 | 30 |
|
||||
| Stage3 | 12 | 60000 | 18 |
|
||||
|
||||
|
||||
## Evaluation
|
||||
### Evaluation Script
|
||||
You can evaluate dexvla by following scripts.
|
||||
~~~shell
|
||||
python lerobot/scripts/eval.py \
|
||||
--policy.type dexvla \
|
||||
--policy.pretrained_path /path/to/pretrained/stage2/or/stage3/weights \
|
||||
--env.type aloha \
|
||||
--env.episode_length 5 \
|
||||
--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
|
||||
--env.task AlohaInsertion-v0 \
|
||||
--eval.n_episodes 1 \
|
||||
--eval.batch_size 1
|
||||
~~~
|
||||
|
||||
### Inference Speed
|
||||
Tested on a single A6000 GPU, the DexVLA could infer 3.4 action chunks in one second. For each action chunk, if we execute 25 actions, the real control frequency can be 85 (3.4*25)Hz.
|
|
@ -0,0 +1,179 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Qwen2VL model configuration"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
ConstantWithWarmupSchedulerConfig,
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
from .policy_heads import register_policy_heads
|
||||
from .qwe2_vla import register_qwen2_vla
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
register_policy_heads()
|
||||
register_qwen2_vla()
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("dexvla")
|
||||
@dataclass
|
||||
class DexVLAConfig(PreTrainedConfig):
|
||||
# For loading policy head
|
||||
policy_head_type: str = "scale_dp_policy"
|
||||
policy_head_size: str = "scaledp_l"
|
||||
action_dim: int = 14
|
||||
state_dim: int = 14
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
n_obs_steps: int = 1
|
||||
|
||||
device: str = "cuda"
|
||||
|
||||
hidden_size: int = 1536
|
||||
qwen2_vl_path: str = (
|
||||
None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
|
||||
)
|
||||
|
||||
pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3
|
||||
pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1)
|
||||
|
||||
training_stage: int = 2 # specific training stage, [2, 3]
|
||||
using_film: bool = True
|
||||
llm_loss_weight: float = 1.0
|
||||
with_llm_head: bool = True
|
||||
using_reasoning: bool = True
|
||||
resize_size: tuple = (240, 320)
|
||||
# Training presets
|
||||
optimizer_lr: float = 2e-5
|
||||
optimizer_betas: Tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
|
||||
scheduler_warmup_steps: int = 2_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
# "VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if self.using_reasoning:
|
||||
assert self.using_film, "using_reasoning requires `using_film=True`"
|
||||
assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`"
|
||||
print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.")
|
||||
else:
|
||||
print(
|
||||
"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'."
|
||||
)
|
||||
|
||||
if self.qwen2_vl_path is None:
|
||||
raise ValueError(
|
||||
"DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'."
|
||||
)
|
||||
|
||||
if self.policy_head_type == "scale_dp_policy":
|
||||
self.policy_head_config = AutoConfig.for_model(
|
||||
model_type=self.policy_head_type,
|
||||
model_size=self.policy_head_size,
|
||||
cond_dim=self.hidden_size,
|
||||
action_dim=self.action_dim,
|
||||
prediction_horizon=self.chunk_size,
|
||||
state_dim=self.state_dim,
|
||||
)
|
||||
elif self.policy_head_type == "unet_diffusion":
|
||||
self.policy_head_config = AutoConfig.for_model(
|
||||
model_type=self.policy_head_type,
|
||||
global_cond_dim=self.hidden_size,
|
||||
action_dim=self.action_dim,
|
||||
state_dim=self.state_dim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Policy head type {self.policy_head_type} not supported")
|
||||
|
||||
if self.training_stage not in [2, 3]:
|
||||
raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.")
|
||||
|
||||
self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
# for i in range(self.empty_cameras):
|
||||
# key = f"observation.images.empty_camera_{i}"
|
||||
# empty_camera = PolicyFeature(
|
||||
# type=FeatureType.VISUAL,
|
||||
# shape=(3, 480, 640),
|
||||
# )
|
||||
# self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
if self.training_stage == 3:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
else:
|
||||
return ConstantWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
|
@ -0,0 +1,58 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ActionProjector(nn.Module):
|
||||
def __init__(self, in_dim, out_dim=1024):
|
||||
super().__init__()
|
||||
self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
|
||||
self.mlps = nn.ModuleList(
|
||||
[
|
||||
# nn.LayerNorm(in_dim),
|
||||
nn.Linear(in_dim, in_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.Dropout(0.0),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.global_1d_pool(x.permute(1, 0)).permute(1, 0)
|
||||
for mlp in self.mlps:
|
||||
x = mlp(x)
|
||||
return x
|
||||
|
||||
|
||||
class FiLM(nn.Module):
|
||||
def __init__(self, feature_dim, condition_dim):
|
||||
super().__init__()
|
||||
self.scale_fc = nn.Linear(condition_dim, feature_dim)
|
||||
self.shift_fc = nn.Linear(condition_dim, feature_dim)
|
||||
|
||||
nn.init.zeros_(self.scale_fc.weight)
|
||||
nn.init.zeros_(self.scale_fc.bias)
|
||||
nn.init.zeros_(self.shift_fc.weight)
|
||||
nn.init.zeros_(self.shift_fc.bias)
|
||||
|
||||
def forward(self, x, condition):
|
||||
# calculate scale and shift
|
||||
scale = self.scale_fc(condition)
|
||||
shift = self.shift_fc(condition)
|
||||
|
||||
# film
|
||||
return x * (1 + scale) + shift
|
|
@ -0,0 +1,313 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from safetensors.torch import load_file
|
||||
from torch import Tensor
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
||||
|
||||
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
|
||||
from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class DexVLAPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = DexVLAConfig
|
||||
name = "dexvla"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DexVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]:
|
||||
setattr(config.qwen2_vla_config, k, config.__dict__[k])
|
||||
|
||||
# if self.config.training_stage == 2:
|
||||
# self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16)
|
||||
model_base = self.config.qwen2_vl_path
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_base,
|
||||
config=config.qwen2_vla_config,
|
||||
trust_remote_code=True,
|
||||
_fast_init=False,
|
||||
# attn_implementation="flash_attention_2",
|
||||
).to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
if self.config.pretrained_scaledp_path is not None:
|
||||
print(
|
||||
"\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
|
||||
)
|
||||
pretrain_scaledp_weights = load_file(self.config.pretrained_scaledp_path)
|
||||
|
||||
keys_to_del_dit = []
|
||||
pretrain_scaledp_weights = {
|
||||
k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items()
|
||||
}
|
||||
for k in pretrain_scaledp_weights:
|
||||
if "noise_pred" not in k: # del weights of vision backbones
|
||||
keys_to_del_dit.append(k)
|
||||
if "cond_obs_emb" in k:
|
||||
keys_to_del_dit.append(k)
|
||||
for k in keys_to_del_dit:
|
||||
del pretrain_scaledp_weights[k]
|
||||
pretrain_scaledp_weights = {
|
||||
k[15:] if k.startswith("noise_pred_net.") else k: v
|
||||
for k, v in pretrain_scaledp_weights.items()
|
||||
}
|
||||
|
||||
self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False)
|
||||
|
||||
self.model.requires_grad_(False)
|
||||
self.model.policy_head.requires_grad_(True)
|
||||
self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path)
|
||||
self.vla_processor = Qwen2VLAProcess(
|
||||
tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor
|
||||
) # process the input data into VLM format
|
||||
|
||||
self.resize_size = self.config.resize_size
|
||||
ratio = 0.95
|
||||
self.transformations = [
|
||||
transforms.Resize(size=self.resize_size, antialias=True),
|
||||
transforms.RandomCrop(size=[int(self.resize_size[0] * ratio), int(self.resize_size[1] * ratio)]),
|
||||
transforms.Resize(self.resize_size, antialias=True),
|
||||
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
|
||||
transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), # , hue=0.08)
|
||||
]
|
||||
|
||||
self.reset()
|
||||
|
||||
def process_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Applying DexVLA preprocessing to original data. Including resizing images. Scaling the range of actions, states."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
task_descs = batch["task"]
|
||||
try:
|
||||
reasonings = batch["reasoning"]
|
||||
except KeyError:
|
||||
reasonings = ["None."] * len(task_descs)
|
||||
|
||||
pass
|
||||
is_pad = batch["action_is_pad"]
|
||||
all_cam_images = []
|
||||
for k in present_img_keys:
|
||||
all_cam_images.append(batch[k])
|
||||
|
||||
# construct observations, and scale 0-1 to 0-255
|
||||
image_data = torch.stack(all_cam_images) * 255
|
||||
image_data = image_data.to(dtype=torch.uint8)
|
||||
# construct observations
|
||||
qpos_data = batch["observation.state"].float()
|
||||
action_data = batch["action"].float()
|
||||
|
||||
orig_shape = image_data.shape
|
||||
image_data = image_data.view(-1, *orig_shape[2:])
|
||||
|
||||
for transform in self.transformations:
|
||||
image_data = transform(image_data)
|
||||
|
||||
image_data = image_data.view(*orig_shape[:3], *self.resize_size)
|
||||
|
||||
vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings}
|
||||
# processing vl_data into qwen2_vl format
|
||||
vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning)
|
||||
vla_inputs["states"] = qpos_data
|
||||
vla_inputs["is_pad"] = is_pad
|
||||
vla_inputs["actions"] = action_data
|
||||
return vla_inputs
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
processed_batch = self.process_batch(batch)
|
||||
|
||||
ret = self.model.forward(**processed_batch)
|
||||
loss_dict = ret["loss"]
|
||||
loss = loss_dict["loss"].mean()
|
||||
return loss, loss_dict
|
||||
|
||||
def dexvla_predict_action(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
actions=None,
|
||||
states=None,
|
||||
is_pad=None,
|
||||
tokenizer=None,
|
||||
is_eval=True,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
image_grid_spatiotemporal=None,
|
||||
):
|
||||
input_ids = input_ids.to("cuda")
|
||||
with torch.inference_mode():
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
image_grid_spatiotemporal=image_grid_spatiotemporal,
|
||||
is_eval=is_eval,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
temperature=0.2,
|
||||
max_new_tokens=256,
|
||||
eos_token_id=tokenizer.eos_token_id, # End of sequence token
|
||||
pad_token_id=tokenizer.eos_token_id, # Pad token
|
||||
use_cache=True,
|
||||
output_hidden_states=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
output_ids = outputs.sequences
|
||||
# last_hidden_states = outputs.hidden_states[-2][-1]
|
||||
input_token_len = input_ids.shape[1]
|
||||
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
||||
if n_diff_input_output > 0:
|
||||
print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
|
||||
outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0]
|
||||
|
||||
outputs_text = outputs_text.strip()
|
||||
last_hidden_states = [each[-1] for each in outputs.hidden_states] # all hidden states
|
||||
all_hidden_states = torch.cat(last_hidden_states, dim=1)
|
||||
|
||||
action_hidden_states = None
|
||||
labels_input = torch.ones((1, input_token_len)) * -100
|
||||
labels_output = torch.ones((1, output_ids.shape[1] - input_token_len))
|
||||
labels = torch.cat([labels_input, labels_output], dim=1)
|
||||
|
||||
if self.model.using_film:
|
||||
action_hidden_states = self.model.film_forward(
|
||||
labels=labels,
|
||||
input_ids=output_ids,
|
||||
hidden_states=torch.cat(last_hidden_states, dim=1),
|
||||
)
|
||||
|
||||
action = self.model.policy_head(
|
||||
actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad
|
||||
)
|
||||
return action, outputs_text
|
||||
|
||||
def tinyvla_predict_action(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
actions=None,
|
||||
states=None,
|
||||
is_pad=None,
|
||||
is_eval=True,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
image_grid_spatiotemporal=None,
|
||||
):
|
||||
input_ids = input_ids.to("cuda")
|
||||
with torch.inference_mode():
|
||||
all_hidden_states = self.model.forward(
|
||||
input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
image_grid_spatiotemporal=image_grid_spatiotemporal,
|
||||
is_eval=is_eval,
|
||||
tinyvla=True,
|
||||
)
|
||||
|
||||
all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1)
|
||||
|
||||
action = self.model.policy_head(
|
||||
actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad
|
||||
)
|
||||
return action, "tinyvla generates no reasoning"
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
try:
|
||||
task_descs = batch["task"]
|
||||
except KeyError:
|
||||
task_descs = " "
|
||||
print("No task descriptions found for this task")
|
||||
|
||||
all_cam_images = []
|
||||
for k in present_img_keys:
|
||||
all_cam_images.append(batch[k])
|
||||
|
||||
# construct observations, and scale 0-1 to 0-255
|
||||
image_data = torch.stack(all_cam_images) * 255
|
||||
image_data = image_data.to(dtype=torch.uint8)
|
||||
# construct observations
|
||||
qpos_data = batch["observation.state"].float()
|
||||
|
||||
image_data = image_data.squeeze(0)
|
||||
|
||||
for transform in self.transformations:
|
||||
image_data = transform(image_data)
|
||||
|
||||
# processing vl_data into qwen2_vl format
|
||||
vla_inputs = self.vla_processor.single_forward_process(
|
||||
images=image_data, raw_lang=task_descs, reasoning=None, eval=True
|
||||
)
|
||||
vla_inputs["states"] = qpos_data
|
||||
|
||||
if self.config.using_film and self.config.with_llm_head: # dexvla
|
||||
all_actions, outputs = self.dexvla_predict_action(
|
||||
**vla_inputs, is_eval=True, tokenizer=self.tokenizer
|
||||
)
|
||||
else: # tinyvla
|
||||
all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True)
|
||||
|
||||
actions = self.unnormalize_outputs({"action": all_actions})["action"]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
from .configuration_scaledp import ScaleDPPolicyConfig
|
||||
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
|
||||
from .modeling_scaledp import ScaleDP
|
||||
from .modeling_unet_diffusion import ConditionalUnet1D
|
||||
|
||||
|
||||
def register_policy_heads():
|
||||
AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
|
||||
AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
|
||||
AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
|
||||
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
|
|
@ -0,0 +1,123 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MODEL_STRUCTURE = {
|
||||
"scaledp_h": {
|
||||
"depth": 32,
|
||||
"n_emb": 1280,
|
||||
"num_heads": 16,
|
||||
},
|
||||
"scaledp_l": {
|
||||
"depth": 24,
|
||||
"n_emb": 1024,
|
||||
"num_heads": 16,
|
||||
}, # 400M
|
||||
}
|
||||
|
||||
|
||||
class ScaleDPPolicyConfig(PretrainedConfig):
|
||||
"""
|
||||
Configuration for ScaleDP policy head
|
||||
"""
|
||||
|
||||
model_type = "scale_dp_policy"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
eval: bool = False,
|
||||
action_dim: int = 14, # action dim
|
||||
# output_dim: int = 14, # action dim
|
||||
cond_dim: int = 1536, # the input dim of the condition
|
||||
state_dim: int = 14, # the input dim of the state
|
||||
prediction_horizon: int = 16, # horizon
|
||||
n_obs_steps: int = 2, # number of observation steps
|
||||
depth: int = 28, # number of DiT blocks
|
||||
n_emb: int = 256, # embedding size
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: int = 4.0,
|
||||
time_as_cond: bool = True,
|
||||
obs_as_cond: bool = True,
|
||||
learn_sigma: bool = False,
|
||||
model_size: str = "none",
|
||||
num_inference_timesteps: int = 10,
|
||||
noise_samples: int = 1,
|
||||
num_train_timesteps: int = 100,
|
||||
**kwargs,
|
||||
):
|
||||
if model_size != "none":
|
||||
depth = MODEL_STRUCTURE[model_size]["depth"]
|
||||
n_emb = MODEL_STRUCTURE[model_size]["n_emb"]
|
||||
num_heads = MODEL_STRUCTURE[model_size]["num_heads"]
|
||||
else:
|
||||
# raise ValueError("model_size show not be 'none'")
|
||||
pass
|
||||
# print("model_size should not be 'none'")
|
||||
self.eval = eval
|
||||
|
||||
self.input_dim = action_dim
|
||||
self.output_dim = action_dim
|
||||
self.prediction_horizon = prediction_horizon
|
||||
|
||||
self.cond_dim = cond_dim
|
||||
self.state_dim = state_dim
|
||||
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.depth = depth
|
||||
self.n_emb = n_emb
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.time_as_cond = time_as_cond
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.learn_sigma = learn_sigma
|
||||
|
||||
self.num_inference_timesteps = num_inference_timesteps
|
||||
self.num_queries = prediction_horizon
|
||||
self.noise_samples = noise_samples
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# get the vision config dict if we are loading from CLIPConfig
|
||||
if config_dict.get("model_type") == "llava_pythia":
|
||||
config_dict = config_dict["action_head"]
|
||||
|
||||
if (
|
||||
"model_type" in config_dict
|
||||
and hasattr(cls, "model_type")
|
||||
and config_dict["model_type"] != cls.model_type
|
||||
):
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
|
@ -0,0 +1,86 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class UnetDiffusionPolicyConfig(PretrainedConfig):
|
||||
"""
|
||||
Configuration for dit diffusion policy head
|
||||
"""
|
||||
|
||||
model_type = "unet_diffusion_policy"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_dim=10,
|
||||
global_cond_dim=2048,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=None,
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
state_dim=7,
|
||||
prediction_horizon=16,
|
||||
noise_samples=1,
|
||||
num_inference_timesteps=10,
|
||||
num_train_timesteps=100,
|
||||
**kwargs,
|
||||
):
|
||||
if down_dims is None:
|
||||
down_dims = [256, 512, 1024]
|
||||
self.input_dim = action_dim
|
||||
self.noise_samples = noise_samples
|
||||
self.prediction_horizon = prediction_horizon
|
||||
self.num_inference_timesteps = num_inference_timesteps
|
||||
self.global_cond_dim = global_cond_dim
|
||||
self.diffusion_step_embed_dim = diffusion_step_embed_dim
|
||||
self.down_dims = down_dims
|
||||
self.kernel_size = kernel_size
|
||||
self.n_groups = n_groups
|
||||
self.state_dim = state_dim
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# get the vision config dict if we are loading from CLIPConfig
|
||||
if config_dict.get("model_type") == "llava_pythia":
|
||||
config_dict = config_dict["action_head"]
|
||||
|
||||
if (
|
||||
"model_type" in config_dict
|
||||
and hasattr(cls, "model_type")
|
||||
and config_dict["model_type"] != cls.model_type
|
||||
):
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
|
@ -0,0 +1,561 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as func
|
||||
import torch.utils.checkpoint
|
||||
from timm.models.vision_transformer import Mlp, use_fused_attn
|
||||
from torch.jit import Final
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from .configuration_scaledp import ScaleDPPolicyConfig
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
fused_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.fused_attn = use_fused_attn()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
|
||||
b, n, c = x.shape
|
||||
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = func.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
# attn = q @ k.transpose(-2, -1)
|
||||
# if attn_mask is not None:
|
||||
# attn += attn_mask
|
||||
# attn = attn.softmax(dim=-1)
|
||||
# attn = self.attn_drop(attn)
|
||||
# x = attn @ v
|
||||
attn_scores = torch.matmul(q, k.transpose(-2, -1))
|
||||
|
||||
# Add attention mask if provided
|
||||
if attn_mask is not None:
|
||||
attn_scores += attn_mask
|
||||
|
||||
# Apply softmax to get attention weights (softmax is applied along the last dimension)
|
||||
attn_weights = func.softmax(attn_scores, dim=-1)
|
||||
|
||||
# Dropout on attention weights (if dropout is used)
|
||||
attn_weights = self.attn_drop(attn_weights)
|
||||
|
||||
# Apply attention weights to value tensor (V)
|
||||
x = torch.matmul(attn_weights, v)
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, n, c)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.bfloat16) / half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding.to(dtype=torch.bfloat16)
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core ScaleDP Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class ScaleDPBlock(nn.Module):
|
||||
"""
|
||||
A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conScaleDPioning.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
def approx_gelu():
|
||||
return nn.GELU(approximate="tanh")
|
||||
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
|
||||
6, dim=1
|
||||
)
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn(
|
||||
modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask
|
||||
) # norm, scale&shift, attn, scale,
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of ScaleDP.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, output_dim):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, output_dim, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class ScaleDP(PreTrainedModel):
|
||||
"""
|
||||
Diffusion models with a Transformer backbone.
|
||||
"""
|
||||
|
||||
config_class = ScaleDPPolicyConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ScaleDPPolicyConfig,
|
||||
):
|
||||
super().__init__(config)
|
||||
# compute number of tokens for main trunk and conScaleDPion encoder
|
||||
if config.n_obs_steps is None:
|
||||
config.n_obs_steps = config.prediction_horizon
|
||||
t = config.prediction_horizon
|
||||
t_cond = 1
|
||||
if not config.time_as_cond:
|
||||
t += 1
|
||||
t_cond -= 1
|
||||
obs_as_cond = config.cond_dim > 0
|
||||
if obs_as_cond:
|
||||
assert config.time_as_cond
|
||||
t_cond += config.n_obs_steps
|
||||
|
||||
# self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
|
||||
self.combine = nn.Sequential(
|
||||
nn.Linear(config.cond_dim + config.state_dim, 1024),
|
||||
nn.ReLU(),
|
||||
nn.Linear(1024, 1024),
|
||||
nn.ReLU(),
|
||||
nn.Linear(1024, config.cond_dim),
|
||||
)
|
||||
self.learn_sigma = config.learn_sigma
|
||||
self.input_dim = config.input_dim
|
||||
self.output_dim = config.output_dim * 2 if config.learn_sigma else config.output_dim
|
||||
self.num_heads = config.num_heads
|
||||
|
||||
self.x_embedder = nn.Linear(config.input_dim, config.n_emb)
|
||||
self.t_embedder = TimestepEmbedder(config.n_emb)
|
||||
self.cond_obs_emb = None
|
||||
if obs_as_cond:
|
||||
self.cond_obs_emb = nn.Linear(config.cond_dim, config.n_emb)
|
||||
|
||||
# Will use fixed sin-cos embedding:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb))
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio)
|
||||
for _ in range(config.depth)
|
||||
]
|
||||
)
|
||||
self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
|
||||
# self.initialize_weights()
|
||||
# constants
|
||||
self.t = t
|
||||
self.t_cond = t_cond
|
||||
self.prediction_horizon = config.prediction_horizon
|
||||
self.time_as_cond = config.time_as_cond
|
||||
self.action_dim = config.output_dim
|
||||
self.obs_as_cond = obs_as_cond
|
||||
logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
# self.proj_to_action = nn.Identity()
|
||||
self.noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=config.num_train_timesteps, # 100
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
set_alpha_to_one=True,
|
||||
steps_offset=0,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
self.num_queries = config.num_queries # 16
|
||||
self.noise_samples = config.noise_samples # 1
|
||||
# self.num_inference_timesteps = config.num_inference_timesteps # 100
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
nn.init.normal_(self.pos_embed, mean=0.0, std=0.02)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.cond_obs_emb.weight, mean=0.0, std=0.02)
|
||||
nn.init.constant_(self.cond_obs_emb.bias, 0)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out adaLN modulation layers in ScaleDP blocks:
|
||||
for block in self.blocks:
|
||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||
"""
|
||||
This long function is unfortunately doing something very simple and is being very defensive:
|
||||
We are separating out all parameters of the models into two buckets: those that will experience
|
||||
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
||||
We are then returning the PyTorch optimizer object.
|
||||
"""
|
||||
|
||||
# separate out all parameters to those that will and won't experience regularizing weight decay
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, Attention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
for mn, m in self.named_modules():
|
||||
for pn, _p in m.named_parameters():
|
||||
fpn = "{}.{}".format(mn, pn) if mn else pn # full param name
|
||||
|
||||
if pn.endswith("bias"):
|
||||
# all biases will not be decayed
|
||||
no_decay.add(fpn)
|
||||
elif pn.startswith("bias"):
|
||||
# MultiheadAttention bias starts with "bias"
|
||||
no_decay.add(fpn)
|
||||
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
|
||||
# weights of whitelist modules will be weight decayed
|
||||
decay.add(fpn)
|
||||
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
|
||||
# weights of blacklist modules will NOT be weight decayed
|
||||
no_decay.add(fpn)
|
||||
|
||||
# validate that we considered every parameter
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
)
|
||||
|
||||
# create the pytorch optimizer object
|
||||
optim_groups = [
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(decay)],
|
||||
"weight_decay": weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
return optim_groups
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.95),
|
||||
):
|
||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
return optimizer
|
||||
|
||||
def forward(self, actions, hidden_states, states, is_pad):
|
||||
"""
|
||||
Forward pass for the diffusion head.
|
||||
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
|
||||
:param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [b,Tokens, D] 8 1200 1024
|
||||
:param states: robot states, shape [b, D]
|
||||
:return: loss
|
||||
"""
|
||||
if actions is not None: # training time
|
||||
b = actions.size(0)
|
||||
actions = actions[:, : self.num_queries]
|
||||
is_pad = is_pad[:, : self.num_queries]
|
||||
num_noise_samples = self.noise_samples
|
||||
# sample noise to add to actions
|
||||
noise = torch.randn(
|
||||
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
|
||||
) # num_noise, b, Ta, D(1, 2, 16, 14)
|
||||
# sample a diffusion iteration for each data point
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
|
||||
).long()
|
||||
|
||||
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
|
||||
|
||||
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
|
||||
# (this is the forward diffusion process)
|
||||
noisy_actions = torch.cat(
|
||||
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
|
||||
dim=0,
|
||||
) # [num_noise_samples * b, Ta, action_dim]
|
||||
|
||||
noisy_actions = noisy_actions.to(dtype=actions.dtype)
|
||||
assert hidden_states.ndim == 3
|
||||
|
||||
hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
|
||||
timesteps = timesteps.repeat(num_noise_samples)
|
||||
is_pad = is_pad.repeat(num_noise_samples, 1)
|
||||
states = states.repeat(num_noise_samples, 1)
|
||||
|
||||
noise_pred = self.model_forward(
|
||||
noisy_actions, timesteps, global_cond=hidden_states, states=states
|
||||
)
|
||||
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
|
||||
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
|
||||
# loss_dict['loss'] = loss
|
||||
return {"loss": loss}
|
||||
# return loss
|
||||
else: # inference time
|
||||
b = 1
|
||||
tp = self.num_queries
|
||||
action_dim = self.action_dim
|
||||
|
||||
# initialize action from Gaussian noise
|
||||
noisy_action = torch.randn((b, tp, action_dim)).cuda()
|
||||
|
||||
naction = noisy_action.to(dtype=hidden_states.dtype)
|
||||
# init scheduler
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||
|
||||
for k in self.noise_scheduler.timesteps:
|
||||
# predict noise
|
||||
noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)
|
||||
|
||||
# inverse diffusion step (remove noise)
|
||||
naction = self.noise_scheduler.step(
|
||||
model_output=noise_pred, timestep=k, sample=naction
|
||||
).prev_sample
|
||||
|
||||
return naction
|
||||
|
||||
def model_forward(self, x, t, global_cond, states):
|
||||
"""
|
||||
Forward pass of ScaleDP.
|
||||
x: (N, T, input_dim) noisy actions
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
global_cond: (N, n_obs_steps, D) tensor of conScaleDPions: image embeddings
|
||||
"""
|
||||
global_cond = global_cond.squeeze(1)
|
||||
global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
|
||||
global_cond = self.combine(global_cond)
|
||||
|
||||
if not torch.is_tensor(t):
|
||||
t = torch.tensor([t], dtype=torch.long, device=x.device)
|
||||
elif torch.is_tensor(t) and len(t.shape) == 0:
|
||||
t = t[None].to(x.device)
|
||||
t = t.expand(t.shape[0])
|
||||
|
||||
x = self.x_embedder(x) + self.pos_embed.to(
|
||||
device=x.device, dtype=x.dtype
|
||||
) # (N, T, D), where T = prediction_horizon
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
if self.obs_as_cond:
|
||||
global_cond = self.cond_obs_emb(global_cond) # (N, D)
|
||||
# c = t + global_cond.sum(dim=1) # (N, D)
|
||||
c = t + global_cond # (N, D)
|
||||
for block in self.blocks:
|
||||
# x = block(x, c, attn_mask=self.mask) # (N, T, D)
|
||||
x = block(x, c, attn_mask=None) # (N, T, D)
|
||||
x = self.final_layer(x, c) # (N, T, output_dim)
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
#################################################################################
|
||||
# ScaleDP Configs #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def scaledp_h(**kwargs):
|
||||
return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs)
|
||||
|
||||
|
||||
def scaledp_l(**kwargs):
|
||||
return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)
|
|
@ -0,0 +1,387 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# requires diffusers==0.11.1
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
|
||||
|
||||
# =================== UNet for Diffusion ==============
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim, dtype):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
]
|
||||
)
|
||||
|
||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||
# predicts per-channel scale and bias
|
||||
cond_channels = out_channels * 2
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
"""
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
|
||||
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:, 0, ...]
|
||||
bias = embed[:, 1, ...]
|
||||
out = scale * out + bias
|
||||
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(PreTrainedModel):
|
||||
_no_split_modules = ["mid_modules", "down_modules", "up_modules"]
|
||||
|
||||
config_class = UnetDiffusionPolicyConfig
|
||||
|
||||
def __init__(self, config: UnetDiffusionPolicyConfig):
|
||||
"""
|
||||
input_dim: Dim of actions.
|
||||
global_cond_dim: Dim of global conditioning applied with FiLM
|
||||
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
|
||||
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
|
||||
down_dims: Channel size for each UNet level.
|
||||
The length of this array determines number of levels.
|
||||
kernel_size: Conv kernel size
|
||||
n_groups: Number of groups for GroupNorm
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
all_dims = [config.input_dim] + list(config.down_dims)
|
||||
start_dim = config.down_dims[0]
|
||||
|
||||
self.num_queries = config.prediction_horizon
|
||||
self.noise_samples = config.noise_samples
|
||||
# self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
|
||||
# self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim)
|
||||
self.norm_after_pool = nn.LayerNorm(config.global_cond_dim)
|
||||
self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim)
|
||||
dsed = config.diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed, torch.bfloat16),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed + config.global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out * 2,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=config.kernel_size,
|
||||
n_groups=config.n_groups,
|
||||
),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size),
|
||||
nn.Conv1d(start_dim, config.input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters())))
|
||||
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
# self.proj_to_action = nn.Identity()
|
||||
self.noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=config.num_train_timesteps, # 100
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
set_alpha_to_one=True,
|
||||
steps_offset=0,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
|
||||
# self.num_inference_timesteps = config.num_inference_timesteps # 100
|
||||
|
||||
def forward(self, actions, hidden_states, states, is_pad):
|
||||
"""
|
||||
Forward pass for the diffusion head.
|
||||
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
|
||||
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024
|
||||
:param states: robot states, shape [b, D]
|
||||
:return: loss
|
||||
"""
|
||||
if actions is not None: # training time
|
||||
b = actions.size(0)
|
||||
actions = copy.deepcopy(actions[:, : self.num_queries])
|
||||
is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
|
||||
num_noise_samples = self.noise_samples
|
||||
# sample noise to add to actions
|
||||
noise = torch.randn(
|
||||
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
|
||||
) # num_noise, b, Ta, D
|
||||
# sample a diffusion iteration for each data point
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
|
||||
).long()
|
||||
|
||||
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
|
||||
|
||||
# add noise to the clean actions according to the noise magnitude at each diffusion iteration
|
||||
# (this is the forward diffusion process)
|
||||
noisy_actions = torch.cat(
|
||||
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
|
||||
dim=0,
|
||||
) # [num_noise_samples * b, Ta, action_dim]
|
||||
|
||||
noisy_actions = noisy_actions.to(dtype=actions.dtype)
|
||||
assert hidden_states.ndim == 3
|
||||
|
||||
hidden_states = hidden_states.repeat(num_noise_samples, 1, 1)
|
||||
timesteps = timesteps.repeat(num_noise_samples)
|
||||
is_pad = is_pad.repeat(num_noise_samples, 1)
|
||||
states = states.repeat(num_noise_samples, 1)
|
||||
|
||||
noise_pred = self.model_forward(
|
||||
noisy_actions, timesteps, global_cond=hidden_states, states=states
|
||||
)
|
||||
noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:])
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none")
|
||||
loss = (loss * ~is_pad.unsqueeze(-1)).mean()
|
||||
# loss_dict['loss'] = loss
|
||||
return {"loss": loss}
|
||||
# return loss
|
||||
else: # inference time
|
||||
b = 1
|
||||
tp = self.num_queries
|
||||
action_dim = 14
|
||||
|
||||
# initialize action from Gaussian noise
|
||||
noisy_action = torch.randn((b, tp, action_dim)).cuda()
|
||||
|
||||
naction = noisy_action.to(dtype=hidden_states.dtype)
|
||||
# init scheduler
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||
|
||||
for k in self.noise_scheduler.timesteps:
|
||||
# predict noise
|
||||
noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states)
|
||||
|
||||
# inverse diffusion step (remove noise)
|
||||
naction = self.noise_scheduler.step(
|
||||
model_output=noise_pred, timestep=k, sample=naction
|
||||
).prev_sample
|
||||
|
||||
return naction
|
||||
|
||||
def model_forward(
|
||||
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
|
||||
):
|
||||
"""
|
||||
x: (b,T,input_dim)
|
||||
timestep: (b,) or int, diffusion step
|
||||
global_cond: (b,global_cond_dim)
|
||||
output: (b,T,input_dim)
|
||||
"""
|
||||
# (b,t,c)
|
||||
sample = sample.moveaxis(-1, -2)
|
||||
# (b,c,t)
|
||||
# global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
|
||||
global_cond = global_cond.squeeze(1)
|
||||
|
||||
global_cond = self.norm_after_pool(global_cond)
|
||||
global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
|
||||
global_cond = self.combine(global_cond)
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([global_feature, global_cond], axis=-1)
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
# (b,c,t)
|
||||
x = x.moveaxis(-1, -2)
|
||||
# (b,t,c)
|
||||
return x
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from .configuration_qwen2_vla import Qwen2VLAConfig
|
||||
from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA
|
||||
|
||||
|
||||
def register_qwen2_vla():
|
||||
AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
|
||||
AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
|
|
@ -0,0 +1,254 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen2_vl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=32,
|
||||
embed_dim=1280,
|
||||
hidden_size=3584,
|
||||
hidden_act="quick_gelu",
|
||||
mlp_ratio=4,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=14,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.embed_dim = embed_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if config_dict.get("model_type") == "qwen2_vl":
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
if (
|
||||
"model_type" in config_dict
|
||||
and hasattr(cls, "model_type")
|
||||
and config_dict["model_type"] != cls.model_type
|
||||
):
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class Qwen2VLAConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
|
||||
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 152064):
|
||||
Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen2VLModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 8192):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 29568):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 80):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 64):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 80):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
vision_config (`Dict`, *optional*):
|
||||
The config for the visual encoder initialization.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
|
||||
|
||||
>>> # Initializing a Qwen2VL style configuration
|
||||
>>> configuration = Qwen2VLConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
||||
>>> model = Qwen2VLForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_vla"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=8192,
|
||||
intermediate_size=29568,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
attention_dropout=0.0,
|
||||
vision_config=None,
|
||||
rope_scaling=None,
|
||||
# For loading policy head
|
||||
policy_head_type="scale_dp_policy", # unet_diffusion_policy
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = Qwen2VLVisionConfig(**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = Qwen2VLVisionConfig()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
self.policy_head_type = policy_head_type # for loading policy head
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
||||
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self, ignore_keys={"mrope_section"})
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,172 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from qwen_vl_utils import fetch_image
|
||||
|
||||
|
||||
class Qwen2VLAProcess:
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer=None,
|
||||
max_seq_len=512,
|
||||
multimodal_processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
self.multimodal_processor = multimodal_processor
|
||||
|
||||
def qwen2_image_preprocess(self, each):
|
||||
ele = {}
|
||||
each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8))
|
||||
ele["image"] = each
|
||||
|
||||
ele["resized_height"] = each.height
|
||||
ele["resized_width"] = each.width
|
||||
each = fetch_image(ele)
|
||||
return torch.from_numpy(np.array(each))
|
||||
|
||||
def single_forward_process(self, images, raw_lang, reasoning, eval=False, use_reasoning=True):
|
||||
len_views = images.shape[0]
|
||||
messages = self.construct_chat_data(len_views, raw_lang)
|
||||
|
||||
data_dict = {"messages": messages}
|
||||
|
||||
image_data = torch.chunk(images, len_views, 0)
|
||||
|
||||
images_list = []
|
||||
|
||||
for _i, each in enumerate(image_data):
|
||||
img_pil = self.qwen2_image_preprocess(each)
|
||||
images_list.append(img_pil)
|
||||
|
||||
text = self.multimodal_processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
model_inputs = self.multimodal_processor(
|
||||
text=text,
|
||||
images=images_list,
|
||||
videos=None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if eval:
|
||||
new_dict = {}
|
||||
for k, v in model_inputs.items():
|
||||
if "image_grid" in k:
|
||||
new_dict["image_grid_spatiotemporal"] = v
|
||||
else:
|
||||
new_dict[k] = v
|
||||
return new_dict
|
||||
|
||||
input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
|
||||
answer = reasoning + " Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
|
||||
|
||||
output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
|
||||
output_labels = output_text["input_ids"]
|
||||
model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1)
|
||||
model_inputs["attention_mask"] = torch.cat(
|
||||
(model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1
|
||||
)
|
||||
labels = torch.cat((input_labels, output_labels), dim=-1)
|
||||
|
||||
data_dict["labels"] = labels
|
||||
for k, v in model_inputs.items():
|
||||
if "image_grid" in k:
|
||||
data_dict["image_grid_spatiotemporal"] = v
|
||||
else:
|
||||
data_dict[k] = v
|
||||
return data_dict
|
||||
|
||||
def forward(self, batch, use_reasoning=True):
|
||||
"""This is the main process function for processing vl data into Qwen2_vl format"""
|
||||
all_images = batch["images"]
|
||||
all_images = torch.einsum(
|
||||
"v b c h w -> b v c h w", all_images
|
||||
) # camera_views, batch_size, channel, height, width
|
||||
|
||||
ret_l = []
|
||||
|
||||
for idx, images in enumerate(all_images):
|
||||
raw_lang = batch["raw_langs"][idx]
|
||||
reasoning = batch["reasonings"][idx]
|
||||
ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning)
|
||||
ret_l.append(ret_dict)
|
||||
|
||||
return self.post_process(ret_l)
|
||||
|
||||
def post_process(self, instances):
|
||||
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
|
||||
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
|
||||
|
||||
image_grid_spatiotemporal = torch.stack(
|
||||
[instances["image_grid_spatiotemporal"] for instances in instances]
|
||||
)
|
||||
pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
|
||||
pixel_values_videos = None
|
||||
video_grid_spatiotemporal = None
|
||||
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
|
||||
labels = torch.flip(labels, dims=[1])
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
)
|
||||
input_ids = torch.flip(input_ids, dims=[1])
|
||||
b = input_ids.shape[0]
|
||||
|
||||
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(
|
||||
b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
|
||||
)
|
||||
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
|
||||
|
||||
attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask[0],
|
||||
"labels": labels,
|
||||
"image_grid_spatiotemporal": image_grid_spatiotemporal,
|
||||
"pixel_values_videos": pixel_values_videos,
|
||||
"video_grid_spatiotemporal": video_grid_spatiotemporal,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
def construct_chat_data(self, len_image, raw_lang):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [],
|
||||
},
|
||||
]
|
||||
|
||||
for _i in range(len_image):
|
||||
messages[0]["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"image": None,
|
||||
}
|
||||
)
|
||||
messages[0]["content"].append({"type": "text", "text": ""})
|
||||
messages[0]["content"][-1]["text"] = raw_lang
|
||||
|
||||
return messages
|
|
@ -23,6 +23,7 @@ from lerobot.common.datasets.utils import dataset_to_policy_features
|
|||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
|
@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
|||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "dexvla":
|
||||
from lerobot.common.policies.dexvla.modeling_dexvla import DexVLAPolicy
|
||||
|
||||
return DexVLAPolicy
|
||||
elif name == "pi0fast":
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
|
@ -74,6 +79,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "dexvla":
|
||||
return DexVLAConfig(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
else:
|
||||
|
|
|
@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
|||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
dexvla = ["transformers>=4.45.2", "qwen_vl_utils==0.0.10", "timm==0.9.10"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
|
|
Loading…
Reference in New Issue