diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index 7e158394..e2ebb9e3 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -111,6 +111,32 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): return LambdaLR(optimizer, lr_lambda, -1) +@LRSchedulerConfig.register_subclass("constant_with_warmup") +@dataclass +class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig): + """Used by DexVLA to train Stage2""" + + num_warmup_steps: int + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + def lr_lambda(current_step): + def linear_warmup_schedule(current_step): + if current_step <= 0: + return 1 / (self.num_warmup_steps + 1) + frac = 1 - current_step / self.num_warmup_steps + return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1 + + def constant_schedule(current_step): + return 1 + + if current_step < self.num_warmup_steps: + return linear_warmup_schedule(current_step) + + return constant_schedule(current_step) + + return LambdaLR(optimizer, lr_lambda, -1) + + def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: state_dict = scheduler.state_dict() write_json(state_dict, save_dir / SCHEDULER_STATE) diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index b73ba5f4..00e28269 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .act.configuration_act import ACTConfig as ACTConfig +from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md new file mode 100644 index 00000000..cbf94d8b --- /dev/null +++ b/lerobot/common/policies/dexvla/README.md @@ -0,0 +1,140 @@ +

+DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Policy Learning

+ +This policy is Community Contributed. For more information about DexVLA, you can also refer to [this](https://github.com/juruobenruo/DexVLA). +This is [project website](https://dex-vla.github.io/). + +## Dataset +### Data format +DexVLA takes RGB images, language instructions and states. For our setting, we use three camera views, namely a top camera and two wrist cameras. + +⭐A major difference between DexVLA and other VLAs is: DexVLA takes in raw language, and outputs sub-step reasoning based on current observations. +So you have to add sub-step reasoning in your data for training. + +Specifically, your data should include a key ``reasoning`` which is a list of sub-step reasoning corresponding to each observation. +For example, if the episode is 10 steps. The length of this list should be 10 as well. And it may looks like: +~~~python +reasoning = [ + "This is step 1.", + "This is step 1.", + "This is step 2.", + "This is step 2.", + ... + "This is step 4.", +] +~~~ + +Besides, your data should include another key ``action_is_pad`` which is a bool mask indicating whether this action chunk is padded. +Suppose the size of the action chunk is 5, and the length of the episode is 10. So the action chunk for the last 4 actions must be padded to make sure the length of action chunk is 5. +And the mask looks like: +~~~python +The 6th chunk: [false, false, false, false, true] +The 7th chunk: [false, false, false, true, true] +The 8th chunk: [false, false, true, true, true] +The 9th chunk: [false, true, true, true, true] +~~~ + +### Training Data for DexVLA +The pretraining dataset comprises approximately 100 hours of collected data by ourselves. The dataset mainly including four embodiments which are: moblie Agilex Aloha, single Franka Emika and single UR5e. +We haven't use any public dataset such as Open-X or DROID. + +## 🤗Download Pretrained Weights +### Download official Qwen2_VL weights +We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework. +The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities +for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed +in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post training on VLM itself. You can download the official weights from this link: + +| Model | Link | +|---------------------|----------------------------------------------------------------| +| Qwen2-VL (~2B) | [huggingface](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) | + +**❗❗** After downloading the standard weights, you have to replace the official "config.json" +with our ["config.json"](https://github.com/juruobenruo/DexVLA/blob/main/docs/config.json) designed for VLA. +### Download our pretrained ScaleDP-H weights(Stage 1) +We released our pretrained weights of ScaleDP-H which is trained after Stage1. Now you can download the weights and directly finetuning your data on Stage 2. + +| Model | Link | +|-------------------|----------------------------------------------------------------| +| ScaleDP-H (~1B) | [huggingface](https://huggingface.co/lesjie/scale_dp_h) | +| ScaleDP-L (~400M) | [huggingface](https://huggingface.co/lesjie/scale_dp_l) | + +**❗❗**After downloading the weights, you have to transform it into ``safetensors`` format, you can simply run this code: +~~~python +import torch +from safetensors.torch import save_file +path = "/path/to/open_scale_dp_l_backbone.ckpt" +checkpoint = torch.load(path, map_location=torch.device('cpu'))['nets']['nets'] + +# Save the weights in safetensors format +safetensors_path = "/path/to/open_scale_dp_l_backbone.safetensors" +save_file(checkpoint, safetensors_path) +print(f"Converted {path} to {safetensors_path}") +pass + +~~~ + +## 🦾Train +We have already provided pretrained weights of ScaleDP which is stage 1. Belows are mainly about training process of Stage2 and Stage3. + +### Training Stage 2 +~~~shell +python lerobot/scripts/train.py \ +--policy.type dexvla \ +--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ +--policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \ +--policy.policy_head_size 'scaledp_h' \ +--policy.training_stage 2 \ +--dataset.repo_i lerobot/aloha_mobile_chair \ +--policy.using_film true \ +--output_dir /path/to/output \ +--steps 10000 \ +--save_freq 1000 \ +--optimizer_lr 2e-5 +~~~ + +### Training Stage 3 +Stage3 can be viewed as continual training on specific dexterous tasks like laundry folding which is same as PI0. So stage3 is trained based on stage2. +~~~shell +python lerobot/scripts/train.py \ +--policy.type dexvla \ +--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ +--.pretrained_path /path/to/pretrained/stage2/weights \ +--policy.policy_head_size 'scaledp_h' \ +--policy.training_stage 3 \ +--dataset.repo_i lerobot/aloha_mobile_chair \ +--batch_size 2 \ +--policy.using_film true \ +--output_dir /path/to/output \ +--steps 10000 \ +--save_freq 1000 \ +--optimizer_lr 2e-5 +~~~ + +### Training Time +Original DexVLA is trained on 8 x H100 GPUs. And the training time for each stage is listed as follows: + +| Stage | Batch Size(each gpu) | Steps | Time(hour) | +|--------|----------------------|--------|------------| +| Stage1 | 32 | 60000 | 30 | +| Stage2 | 12 | 100000 | 30 | +| Stage3 | 12 | 60000 | 18 | + + +## Evaluation +### Evaluation Script +You can evaluate dexvla by following scripts. +~~~shell +python lerobot/scripts/eval.py \ +--policy.type dexvla \ +--policy.pretrained_path /path/to/pretrained/stage2/or/stage3/weights \ +--env.type aloha \ +--env.episode_length 5 \ +--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ +--env.task AlohaInsertion-v0 \ +--eval.n_episodes 1 \ +--eval.batch_size 1 +~~~ + +### Inference Speed +Tested on a single A6000 GPU, the DexVLA could infer 3.4 action chunks in one second. For each action chunk, if we execute 25 actions, the real control frequency can be 85 (3.4*25)Hz. diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py new file mode 100644 index 00000000..96a3944b --- /dev/null +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen2VL model configuration""" + +from dataclasses import dataclass, field +from typing import Tuple + +from transformers import AutoConfig +from transformers.utils import logging + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + ConstantWithWarmupSchedulerConfig, + CosineDecayWithWarmupSchedulerConfig, +) +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + +from .policy_heads import register_policy_heads +from .qwe2_vla import register_qwen2_vla + +logger = logging.get_logger(__name__) +register_policy_heads() +register_qwen2_vla() + + +@PreTrainedConfig.register_subclass("dexvla") +@dataclass +class DexVLAConfig(PreTrainedConfig): + # For loading policy head + policy_head_type: str = "scale_dp_policy" + policy_head_size: str = "scaledp_l" + action_dim: int = 14 + state_dim: int = 14 + chunk_size: int = 50 + n_action_steps: int = 50 + n_obs_steps: int = 1 + + device: str = "cuda" + + hidden_size: int = 1536 + qwen2_vl_path: str = ( + None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl + ) + + pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3 + pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1) + + training_stage: int = 2 # specific training stage, [2, 3] + using_film: bool = True + llm_loss_weight: float = 1.0 + with_llm_head: bool = True + using_reasoning: bool = True + resize_size: tuple = (240, 320) + # Training presets + optimizer_lr: float = 2e-5 + optimizer_betas: Tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + + scheduler_warmup_steps: int = 2_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + # "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + def __post_init__(self): + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + if self.using_reasoning: + assert self.using_film, "using_reasoning requires `using_film=True`" + assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`" + print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.") + else: + print( + "Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'." + ) + + if self.qwen2_vl_path is None: + raise ValueError( + "DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'." + ) + + if self.policy_head_type == "scale_dp_policy": + self.policy_head_config = AutoConfig.for_model( + model_type=self.policy_head_type, + model_size=self.policy_head_size, + cond_dim=self.hidden_size, + action_dim=self.action_dim, + prediction_horizon=self.chunk_size, + state_dim=self.state_dim, + ) + elif self.policy_head_type == "unet_diffusion": + self.policy_head_config = AutoConfig.for_model( + model_type=self.policy_head_type, + global_cond_dim=self.hidden_size, + action_dim=self.action_dim, + state_dim=self.state_dim, + ) + else: + raise ValueError(f"Policy head type {self.policy_head_type} not supported") + + if self.training_stage not in [2, 3]: + raise ValueError(f"Training stage must be 2 or 3. Got {self.training_stage}.") + + self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path) + + def validate_features(self) -> None: + # TODO: implement value error + if not self.image_features and not self.env_state_feature: + raise ValueError("You must provide at least one image or the environment state among the inputs.") + + # for i in range(self.empty_cameras): + # key = f"observation.images.empty_camera_{i}" + # empty_camera = PolicyFeature( + # type=FeatureType.VISUAL, + # shape=(3, 480, 640), + # ) + # self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self): + if self.training_stage == 3: + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + else: + return ConstantWithWarmupSchedulerConfig( + num_warmup_steps=self.scheduler_warmup_steps, + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/dexvla/fusion_modules.py b/lerobot/common/policies/dexvla/fusion_modules.py new file mode 100644 index 00000000..39bbc57f --- /dev/null +++ b/lerobot/common/policies/dexvla/fusion_modules.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +class ActionProjector(nn.Module): + def __init__(self, in_dim, out_dim=1024): + super().__init__() + self.global_1d_pool = nn.AdaptiveAvgPool1d(1) + self.mlps = nn.ModuleList( + [ + # nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.Dropout(0.0), + ] + ) + + def forward(self, x): + x = self.global_1d_pool(x.permute(1, 0)).permute(1, 0) + for mlp in self.mlps: + x = mlp(x) + return x + + +class FiLM(nn.Module): + def __init__(self, feature_dim, condition_dim): + super().__init__() + self.scale_fc = nn.Linear(condition_dim, feature_dim) + self.shift_fc = nn.Linear(condition_dim, feature_dim) + + nn.init.zeros_(self.scale_fc.weight) + nn.init.zeros_(self.scale_fc.bias) + nn.init.zeros_(self.shift_fc.weight) + nn.init.zeros_(self.shift_fc.bias) + + def forward(self, x, condition): + # calculate scale and shift + scale = self.scale_fc(condition) + shift = self.shift_fc(condition) + + # film + return x * (1 + scale) + shift diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py new file mode 100644 index 00000000..e1133df8 --- /dev/null +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque + +import torch +import torchvision.transforms as transforms +from safetensors.torch import load_file +from torch import Tensor +from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer + +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig +from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy + + +class DexVLAPolicy(PreTrainedPolicy): + """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" + + config_class = DexVLAConfig + name = "dexvla" + + def __init__( + self, + config: DexVLAConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]: + setattr(config.qwen2_vla_config, k, config.__dict__[k]) + + # if self.config.training_stage == 2: + # self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16) + model_base = self.config.qwen2_vl_path + self.model = AutoModelForCausalLM.from_pretrained( + model_base, + config=config.qwen2_vla_config, + trust_remote_code=True, + _fast_init=False, + # attn_implementation="flash_attention_2", + ).to(device="cuda", dtype=torch.bfloat16) + + if self.config.pretrained_scaledp_path is not None: + print( + "\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + ) + pretrain_scaledp_weights = load_file(self.config.pretrained_scaledp_path) + + keys_to_del_dit = [] + pretrain_scaledp_weights = { + k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items() + } + for k in pretrain_scaledp_weights: + if "noise_pred" not in k: # del weights of vision backbones + keys_to_del_dit.append(k) + if "cond_obs_emb" in k: + keys_to_del_dit.append(k) + for k in keys_to_del_dit: + del pretrain_scaledp_weights[k] + pretrain_scaledp_weights = { + k[15:] if k.startswith("noise_pred_net.") else k: v + for k, v in pretrain_scaledp_weights.items() + } + + self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False) + + self.model.requires_grad_(False) + self.model.policy_head.requires_grad_(True) + self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path) + self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path) + self.vla_processor = Qwen2VLAProcess( + tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor + ) # process the input data into VLM format + + self.resize_size = self.config.resize_size + ratio = 0.95 + self.transformations = [ + transforms.Resize(size=self.resize_size, antialias=True), + transforms.RandomCrop(size=[int(self.resize_size[0] * ratio), int(self.resize_size[1] * ratio)]), + transforms.Resize(self.resize_size, antialias=True), + transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False), + transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), # , hue=0.08) + ] + + self.reset() + + def process_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Applying DexVLA preprocessing to original data. Including resizing images. Scaling the range of actions, states.""" + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + present_img_keys = [key for key in self.config.image_features if key in batch] + task_descs = batch["task"] + try: + reasonings = batch["reasoning"] + except KeyError: + reasonings = ["None."] * len(task_descs) + + pass + is_pad = batch["action_is_pad"] + all_cam_images = [] + for k in present_img_keys: + all_cam_images.append(batch[k]) + + # construct observations, and scale 0-1 to 0-255 + image_data = torch.stack(all_cam_images) * 255 + image_data = image_data.to(dtype=torch.uint8) + # construct observations + qpos_data = batch["observation.state"].float() + action_data = batch["action"].float() + + orig_shape = image_data.shape + image_data = image_data.view(-1, *orig_shape[2:]) + + for transform in self.transformations: + image_data = transform(image_data) + + image_data = image_data.view(*orig_shape[:3], *self.resize_size) + + vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings} + # processing vl_data into qwen2_vl format + vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning) + vla_inputs["states"] = qpos_data + vla_inputs["is_pad"] = is_pad + vla_inputs["actions"] = action_data + return vla_inputs + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: + processed_batch = self.process_batch(batch) + + ret = self.model.forward(**processed_batch) + loss_dict = ret["loss"] + loss = loss_dict["loss"].mean() + return loss, loss_dict + + def dexvla_predict_action( + self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + tokenizer=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_spatiotemporal=None, + ): + input_ids = input_ids.to("cuda") + with torch.inference_mode(): + outputs = self.model.generate( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + image_grid_spatiotemporal=image_grid_spatiotemporal, + is_eval=is_eval, + num_beams=1, + do_sample=False, + temperature=0.2, + max_new_tokens=256, + eos_token_id=tokenizer.eos_token_id, # End of sequence token + pad_token_id=tokenizer.eos_token_id, # Pad token + use_cache=True, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + output_ids = outputs.sequences + # last_hidden_states = outputs.hidden_states[-2][-1] + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids") + outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0] + + outputs_text = outputs_text.strip() + last_hidden_states = [each[-1] for each in outputs.hidden_states] # all hidden states + all_hidden_states = torch.cat(last_hidden_states, dim=1) + + action_hidden_states = None + labels_input = torch.ones((1, input_token_len)) * -100 + labels_output = torch.ones((1, output_ids.shape[1] - input_token_len)) + labels = torch.cat([labels_input, labels_output], dim=1) + + if self.model.using_film: + action_hidden_states = self.model.film_forward( + labels=labels, + input_ids=output_ids, + hidden_states=torch.cat(last_hidden_states, dim=1), + ) + + action = self.model.policy_head( + actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad + ) + return action, outputs_text + + def tinyvla_predict_action( + self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_spatiotemporal=None, + ): + input_ids = input_ids.to("cuda") + with torch.inference_mode(): + all_hidden_states = self.model.forward( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + image_grid_spatiotemporal=image_grid_spatiotemporal, + is_eval=is_eval, + tinyvla=True, + ) + + all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1) + + action = self.model.policy_head( + actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad + ) + return action, "tinyvla generates no reasoning" + + def reset(self): + """This should be called whenever the environment is reset.""" + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + def get_optim_params(self) -> dict: + return self.parameters() + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + batch = self.normalize_inputs(batch) + + if len(self._action_queue) == 0: + present_img_keys = [key for key in self.config.image_features if key in batch] + try: + task_descs = batch["task"] + except KeyError: + task_descs = " " + print("No task descriptions found for this task") + + all_cam_images = [] + for k in present_img_keys: + all_cam_images.append(batch[k]) + + # construct observations, and scale 0-1 to 0-255 + image_data = torch.stack(all_cam_images) * 255 + image_data = image_data.to(dtype=torch.uint8) + # construct observations + qpos_data = batch["observation.state"].float() + + image_data = image_data.squeeze(0) + + for transform in self.transformations: + image_data = transform(image_data) + + # processing vl_data into qwen2_vl format + vla_inputs = self.vla_processor.single_forward_process( + images=image_data, raw_lang=task_descs, reasoning=None, eval=True + ) + vla_inputs["states"] = qpos_data + + if self.config.using_film and self.config.with_llm_head: # dexvla + all_actions, outputs = self.dexvla_predict_action( + **vla_inputs, is_eval=True, tokenizer=self.tokenizer + ) + else: # tinyvla + all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True) + + actions = self.unnormalize_outputs({"action": all_actions})["action"] + self._action_queue.extend(actions.transpose(0, 1)) + + return self._action_queue.popleft() diff --git a/lerobot/common/policies/dexvla/policy_heads/__init__.py b/lerobot/common/policies/dexvla/policy_heads/__init__.py new file mode 100644 index 00000000..f3b6a169 --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/__init__.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoConfig, AutoModel + +from .configuration_scaledp import ScaleDPPolicyConfig +from .configuration_unet_diffusion import UnetDiffusionPolicyConfig +from .modeling_scaledp import ScaleDP +from .modeling_unet_diffusion import ConditionalUnet1D + + +def register_policy_heads(): + AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig) + AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig) + AutoModel.register(ScaleDPPolicyConfig, ScaleDP) + AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D) diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py new file mode 100644 index 00000000..e2d71cea --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from typing import Union + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +MODEL_STRUCTURE = { + "scaledp_h": { + "depth": 32, + "n_emb": 1280, + "num_heads": 16, + }, + "scaledp_l": { + "depth": 24, + "n_emb": 1024, + "num_heads": 16, + }, # 400M +} + + +class ScaleDPPolicyConfig(PretrainedConfig): + """ + Configuration for ScaleDP policy head + """ + + model_type = "scale_dp_policy" + + def __init__( + self, + eval: bool = False, + action_dim: int = 14, # action dim + # output_dim: int = 14, # action dim + cond_dim: int = 1536, # the input dim of the condition + state_dim: int = 14, # the input dim of the state + prediction_horizon: int = 16, # horizon + n_obs_steps: int = 2, # number of observation steps + depth: int = 28, # number of DiT blocks + n_emb: int = 256, # embedding size + num_heads: int = 16, + mlp_ratio: int = 4.0, + time_as_cond: bool = True, + obs_as_cond: bool = True, + learn_sigma: bool = False, + model_size: str = "none", + num_inference_timesteps: int = 10, + noise_samples: int = 1, + num_train_timesteps: int = 100, + **kwargs, + ): + if model_size != "none": + depth = MODEL_STRUCTURE[model_size]["depth"] + n_emb = MODEL_STRUCTURE[model_size]["n_emb"] + num_heads = MODEL_STRUCTURE[model_size]["num_heads"] + else: + # raise ValueError("model_size show not be 'none'") + pass + # print("model_size should not be 'none'") + self.eval = eval + + self.input_dim = action_dim + self.output_dim = action_dim + self.prediction_horizon = prediction_horizon + + self.cond_dim = cond_dim + self.state_dim = state_dim + + self.n_obs_steps = n_obs_steps + self.depth = depth + self.n_emb = n_emb + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.time_as_cond = time_as_cond + self.obs_as_cond = obs_as_cond + self.learn_sigma = learn_sigma + + self.num_inference_timesteps = num_inference_timesteps + self.num_queries = prediction_horizon + self.noise_samples = noise_samples + self.num_train_timesteps = num_train_timesteps + super().__init__(**kwargs) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "llava_pythia": + config_dict = config_dict["action_head"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py new file mode 100644 index 00000000..b7eb046e --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Union + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class UnetDiffusionPolicyConfig(PretrainedConfig): + """ + Configuration for dit diffusion policy head + """ + + model_type = "unet_diffusion_policy" + + def __init__( + self, + action_dim=10, + global_cond_dim=2048, + diffusion_step_embed_dim=256, + down_dims=None, + kernel_size=5, + n_groups=8, + state_dim=7, + prediction_horizon=16, + noise_samples=1, + num_inference_timesteps=10, + num_train_timesteps=100, + **kwargs, + ): + if down_dims is None: + down_dims = [256, 512, 1024] + self.input_dim = action_dim + self.noise_samples = noise_samples + self.prediction_horizon = prediction_horizon + self.num_inference_timesteps = num_inference_timesteps + self.global_cond_dim = global_cond_dim + self.diffusion_step_embed_dim = diffusion_step_embed_dim + self.down_dims = down_dims + self.kernel_size = kernel_size + self.n_groups = n_groups + self.state_dim = state_dim + self.num_train_timesteps = num_train_timesteps + + super().__init__(**kwargs) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "llava_pythia": + config_dict = config_dict["action_head"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py new file mode 100644 index 00000000..62ff0587 --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -0,0 +1,561 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as func +import torch.utils.checkpoint +from timm.models.vision_transformer import Mlp, use_fused_attn +from torch.jit import Final +from transformers.modeling_utils import PreTrainedModel + +from .configuration_scaledp import ScaleDPPolicyConfig + +_logger = logging.getLogger(__name__) + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor: + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = func.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + # attn = q @ k.transpose(-2, -1) + # if attn_mask is not None: + # attn += attn_mask + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = attn @ v + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + + # Add attention mask if provided + if attn_mask is not None: + attn_scores += attn_mask + + # Apply softmax to get attention weights (softmax is applied along the last dimension) + attn_weights = func.softmax(attn_scores, dim=-1) + + # Dropout on attention weights (if dropout is used) + attn_weights = self.attn_drop(attn_weights) + + # Apply attention weights to value tensor (V) + x = torch.matmul(attn_weights, v) + + x = x.transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +logger = logging.getLogger(__name__) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.bfloat16) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding.to(dtype=torch.bfloat16) + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +################################################################################# +# Core ScaleDP Model # +################################################################################# + + +class ScaleDPBlock(nn.Module): + """ + A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conScaleDPioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + def approx_gelu(): + return nn.GELU(approximate="tanh") + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c, attn_mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk( + 6, dim=1 + ) + x = x + gate_msa.unsqueeze(1) * self.attn( + modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask + ) # norm, scale&shift, attn, scale, + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of ScaleDP. + """ + + def __init__(self, hidden_size, output_dim): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, output_dim, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class ScaleDP(PreTrainedModel): + """ + Diffusion models with a Transformer backbone. + """ + + config_class = ScaleDPPolicyConfig + + def __init__( + self, + config: ScaleDPPolicyConfig, + ): + super().__init__(config) + # compute number of tokens for main trunk and conScaleDPion encoder + if config.n_obs_steps is None: + config.n_obs_steps = config.prediction_horizon + t = config.prediction_horizon + t_cond = 1 + if not config.time_as_cond: + t += 1 + t_cond -= 1 + obs_as_cond = config.cond_dim > 0 + if obs_as_cond: + assert config.time_as_cond + t_cond += config.n_obs_steps + + # self.combine = nn.Linear(cond_dim+state_dim, cond_dim) + self.combine = nn.Sequential( + nn.Linear(config.cond_dim + config.state_dim, 1024), + nn.ReLU(), + nn.Linear(1024, 1024), + nn.ReLU(), + nn.Linear(1024, config.cond_dim), + ) + self.learn_sigma = config.learn_sigma + self.input_dim = config.input_dim + self.output_dim = config.output_dim * 2 if config.learn_sigma else config.output_dim + self.num_heads = config.num_heads + + self.x_embedder = nn.Linear(config.input_dim, config.n_emb) + self.t_embedder = TimestepEmbedder(config.n_emb) + self.cond_obs_emb = None + if obs_as_cond: + self.cond_obs_emb = nn.Linear(config.cond_dim, config.n_emb) + + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb)) + + self.blocks = nn.ModuleList( + [ + ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) + for _ in range(config.depth) + ] + ) + self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim) + # self.initialize_weights() + # constants + self.t = t + self.t_cond = t_cond + self.prediction_horizon = config.prediction_horizon + self.time_as_cond = config.time_as_cond + self.action_dim = config.output_dim + self.obs_as_cond = obs_as_cond + logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters())) + + from diffusers.schedulers.scheduling_ddim import DDIMScheduler + + self.num_inference_timesteps = config.num_inference_timesteps + # self.proj_to_action = nn.Identity() + self.noise_scheduler = DDIMScheduler( + num_train_timesteps=config.num_train_timesteps, # 100 + beta_schedule="squaredcos_cap_v2", + clip_sample=True, + set_alpha_to_one=True, + steps_offset=0, + prediction_type="epsilon", + ) + self.num_queries = config.num_queries # 16 + self.noise_samples = config.noise_samples # 1 + # self.num_inference_timesteps = config.num_inference_timesteps # 100 + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + nn.init.normal_(self.pos_embed, mean=0.0, std=0.02) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.cond_obs_emb.weight, mean=0.0, std=0.02) + nn.init.constant_(self.cond_obs_emb.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in ScaleDP blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def get_optim_groups(self, weight_decay: float = 1e-3): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the models into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, Attention) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, _p in m.named_parameters(): + fpn = "{}.{}".format(mn, pn) if mn else pn # full param name + + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.startswith("bias"): + # MultiheadAttention bias starts with "bias" + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = dict(self.named_parameters()) + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters {} were not separated into either decay/no_decay set!".format( + str(param_dict.keys() - union_params), + ) + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(decay)], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(no_decay)], + "weight_decay": 0.0, + }, + ] + return optim_groups + + def configure_optimizers( + self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95), + ): + optim_groups = self.get_optim_groups(weight_decay=weight_decay) + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + return optimizer + + def forward(self, actions, hidden_states, states, is_pad): + """ + Forward pass for the diffusion head. + :param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1 + :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [b,Tokens, D] 8 1200 1024 + :param states: robot states, shape [b, D] + :return: loss + """ + if actions is not None: # training time + b = actions.size(0) + actions = actions[:, : self.num_queries] + is_pad = is_pad[:, : self.num_queries] + num_noise_samples = self.noise_samples + # sample noise to add to actions + noise = torch.randn( + [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype + ) # num_noise, b, Ta, D(1, 2, 16, 14) + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device + ).long() + + timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) + + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = torch.cat( + [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], + dim=0, + ) # [num_noise_samples * b, Ta, action_dim] + + noisy_actions = noisy_actions.to(dtype=actions.dtype) + assert hidden_states.ndim == 3 + + hidden_states = hidden_states.repeat(num_noise_samples, 1, 1) + timesteps = timesteps.repeat(num_noise_samples) + is_pad = is_pad.repeat(num_noise_samples, 1) + states = states.repeat(num_noise_samples, 1) + + noise_pred = self.model_forward( + noisy_actions, timesteps, global_cond=hidden_states, states=states + ) + noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none") + loss = (loss * ~is_pad.unsqueeze(-1)).mean() + # loss_dict['loss'] = loss + return {"loss": loss} + # return loss + else: # inference time + b = 1 + tp = self.num_queries + action_dim = self.action_dim + + # initialize action from Gaussian noise + noisy_action = torch.randn((b, tp, action_dim)).cuda() + + naction = noisy_action.to(dtype=hidden_states.dtype) + # init scheduler + self.noise_scheduler.set_timesteps(self.num_inference_timesteps) + + for k in self.noise_scheduler.timesteps: + # predict noise + noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states) + + # inverse diffusion step (remove noise) + naction = self.noise_scheduler.step( + model_output=noise_pred, timestep=k, sample=naction + ).prev_sample + + return naction + + def model_forward(self, x, t, global_cond, states): + """ + Forward pass of ScaleDP. + x: (N, T, input_dim) noisy actions + t: (N,) tensor of diffusion timesteps + global_cond: (N, n_obs_steps, D) tensor of conScaleDPions: image embeddings + """ + global_cond = global_cond.squeeze(1) + global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond + global_cond = self.combine(global_cond) + + if not torch.is_tensor(t): + t = torch.tensor([t], dtype=torch.long, device=x.device) + elif torch.is_tensor(t) and len(t.shape) == 0: + t = t[None].to(x.device) + t = t.expand(t.shape[0]) + + x = self.x_embedder(x) + self.pos_embed.to( + device=x.device, dtype=x.dtype + ) # (N, T, D), where T = prediction_horizon + t = self.t_embedder(t) # (N, D) + if self.obs_as_cond: + global_cond = self.cond_obs_emb(global_cond) # (N, D) + # c = t + global_cond.sum(dim=1) # (N, D) + c = t + global_cond # (N, D) + for block in self.blocks: + # x = block(x, c, attn_mask=self.mask) # (N, T, D) + x = block(x, c, attn_mask=None) # (N, T, D) + x = self.final_layer(x, c) # (N, T, output_dim) + return x + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# ScaleDP Configs # +################################################################################# + + +def scaledp_h(**kwargs): + return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs) + + +def scaledp_l(**kwargs): + return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py new file mode 100644 index 00000000..0dea2e90 --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import Union + +import torch +import torch.nn as nn + +# requires diffusers==0.11.1 +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from transformers.modeling_utils import PreTrainedModel + +from .configuration_unet_diffusion import UnetDiffusionPolicyConfig + +# =================== UNet for Diffusion ============== + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim, dtype): + super().__init__() + self.dim = dim + self.dtype = dtype + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class ConditionalResidualBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8): + super().__init__() + + self.blocks = nn.ModuleList( + [ + Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), + Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), + ] + ) + + # FiLM modulation https://arxiv.org/abs/1709.07871 + # predicts per-channel scale and bias + cond_channels = out_channels * 2 + self.out_channels = out_channels + self.cond_encoder = nn.Sequential( + nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1)) + ) + + # make sure dimensions compatible + self.residual_conv = ( + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + ) + + def forward(self, x, cond): + """ + x : [ batch_size x in_channels x horizon ] + cond : [ batch_size x cond_dim] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + out = self.blocks[0](x) + embed = self.cond_encoder(cond) + + embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) + scale = embed[:, 0, ...] + bias = embed[:, 1, ...] + out = scale * out + bias + + out = self.blocks[1](out) + out = out + self.residual_conv(x) + return out + + +class ConditionalUnet1D(PreTrainedModel): + _no_split_modules = ["mid_modules", "down_modules", "up_modules"] + + config_class = UnetDiffusionPolicyConfig + + def __init__(self, config: UnetDiffusionPolicyConfig): + """ + input_dim: Dim of actions. + global_cond_dim: Dim of global conditioning applied with FiLM + in addition to diffusion step embedding. This is usually obs_horizon * obs_dim + diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k + down_dims: Channel size for each UNet level. + The length of this array determines number of levels. + kernel_size: Conv kernel size + n_groups: Number of groups for GroupNorm + """ + + super().__init__(config) + all_dims = [config.input_dim] + list(config.down_dims) + start_dim = config.down_dims[0] + + self.num_queries = config.prediction_horizon + self.noise_samples = config.noise_samples + # self.global_1d_pool = nn.AdaptiveAvgPool1d(1) + # self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim) + self.norm_after_pool = nn.LayerNorm(config.global_cond_dim) + self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim) + dsed = config.diffusion_step_embed_dim + diffusion_step_encoder = nn.Sequential( + SinusoidalPosEmb(dsed, torch.bfloat16), + nn.Linear(dsed, dsed * 4), + nn.Mish(), + nn.Linear(dsed * 4, dsed), + ) + cond_dim = dsed + config.global_cond_dim + + in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False)) + mid_dim = all_dims[-1] + self.mid_modules = nn.ModuleList( + [ + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ] + ) + + down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + down_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + dim_out, + dim_out, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + Downsample1d(dim_out) if not is_last else nn.Identity(), + ] + ) + ) + + up_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + up_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_out * 2, + dim_in, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + dim_in, + dim_in, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + Upsample1d(dim_in) if not is_last else nn.Identity(), + ] + ) + ) + + final_conv = nn.Sequential( + Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size), + nn.Conv1d(start_dim, config.input_dim, 1), + ) + + self.diffusion_step_encoder = diffusion_step_encoder + self.up_modules = up_modules + self.down_modules = down_modules + self.final_conv = final_conv + + print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters()))) + + self.num_inference_timesteps = config.num_inference_timesteps + # self.proj_to_action = nn.Identity() + self.noise_scheduler = DDIMScheduler( + num_train_timesteps=config.num_train_timesteps, # 100 + beta_schedule="squaredcos_cap_v2", + clip_sample=True, + set_alpha_to_one=True, + steps_offset=0, + prediction_type="epsilon", + ) + + # self.num_inference_timesteps = config.num_inference_timesteps # 100 + + def forward(self, actions, hidden_states, states, is_pad): + """ + Forward pass for the diffusion head. + :param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1 + :param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024 + :param states: robot states, shape [b, D] + :return: loss + """ + if actions is not None: # training time + b = actions.size(0) + actions = copy.deepcopy(actions[:, : self.num_queries]) + is_pad = copy.deepcopy(is_pad[:, : self.num_queries]) + num_noise_samples = self.noise_samples + # sample noise to add to actions + noise = torch.randn( + [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype + ) # num_noise, b, Ta, D + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device + ).long() + + timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) + + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = torch.cat( + [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], + dim=0, + ) # [num_noise_samples * b, Ta, action_dim] + + noisy_actions = noisy_actions.to(dtype=actions.dtype) + assert hidden_states.ndim == 3 + + hidden_states = hidden_states.repeat(num_noise_samples, 1, 1) + timesteps = timesteps.repeat(num_noise_samples) + is_pad = is_pad.repeat(num_noise_samples, 1) + states = states.repeat(num_noise_samples, 1) + + noise_pred = self.model_forward( + noisy_actions, timesteps, global_cond=hidden_states, states=states + ) + noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none") + loss = (loss * ~is_pad.unsqueeze(-1)).mean() + # loss_dict['loss'] = loss + return {"loss": loss} + # return loss + else: # inference time + b = 1 + tp = self.num_queries + action_dim = 14 + + # initialize action from Gaussian noise + noisy_action = torch.randn((b, tp, action_dim)).cuda() + + naction = noisy_action.to(dtype=hidden_states.dtype) + # init scheduler + self.noise_scheduler.set_timesteps(self.num_inference_timesteps) + + for k in self.noise_scheduler.timesteps: + # predict noise + noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states) + + # inverse diffusion step (remove noise) + naction = self.noise_scheduler.step( + model_output=noise_pred, timestep=k, sample=naction + ).prev_sample + + return naction + + def model_forward( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None + ): + """ + x: (b,T,input_dim) + timestep: (b,) or int, diffusion step + global_cond: (b,global_cond_dim) + output: (b,T,input_dim) + """ + # (b,t,c) + sample = sample.moveaxis(-1, -2) + # (b,c,t) + # global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1) + global_cond = global_cond.squeeze(1) + + global_cond = self.norm_after_pool(global_cond) + global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond + global_cond = self.combine(global_cond) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + global_feature = self.diffusion_step_encoder(timesteps) + + if global_cond is not None: + global_feature = torch.cat([global_feature, global_cond], axis=-1) + + x = sample + h = [] + for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + h.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + x = torch.cat((x, h.pop()), dim=1) + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + # (b,c,t) + x = x.moveaxis(-1, -2) + # (b,t,c) + return x diff --git a/lerobot/common/policies/dexvla/qwe2_vla/__init__.py b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py new file mode 100644 index 00000000..35627635 --- /dev/null +++ b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoConfig, AutoModelForCausalLM + +from .configuration_qwen2_vla import Qwen2VLAConfig +from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA + + +def register_qwen2_vla(): + AutoConfig.register("qwen2_vla", Qwen2VLAConfig) + AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py new file mode 100644 index 00000000..1a3e7411 --- /dev/null +++ b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen2VLVisionConfig(PretrainedConfig): + model_type = "qwen2_vl" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if config_dict.get("model_type") == "qwen2_vl": + config_dict = config_dict["vision_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class Qwen2VLAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig + + >>> # Initializing a Qwen2VL style configuration + >>> configuration = Qwen2VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_vla" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + # For loading policy head + policy_head_type="scale_dp_policy", # unet_diffusion_policy + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.policy_head_type = policy_head_type # for loading policy head + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations + # one can set it to "linear"/"dynamic" etc. to have scaled RoPE + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py new file mode 100644 index 00000000..4b656354 --- /dev/null +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -0,0 +1,2046 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2-VL model.""" + +import gc +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as func +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss, LayerNorm +from transformers import AutoConfig, AutoModel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + ModelOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from lerobot.common.policies.dexvla.fusion_modules import ActionProjector, FiLM + +from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from transformers.modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2VLConfig" + + +@dataclass +class Qwen2VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2VLRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[Qwen2VLAConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for spatiotemporal grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = LayerNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class VisionMlp(nn.Module): + def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.act = ACT2FN[hidden_act] + self.fc2 = nn.Linear(hidden_dim, dim) + + def forward(self, x) -> torch.Tensor: + return self.fc2(self.act(self.fc1(x))) + + +class VisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = func.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_VL_VISION_ATTENTION_CLASSES = { + "eager": VisionAttention, + "flash_attention_2": VisionFlashAttention2, + "sdpa": VisionSdpaAttention, +} + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) + self.norm2 = LayerNorm(config.embed_dim, eps=1e-6) + mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) + + self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.embed_dim, num_heads=config.num_heads + ) + self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states).to(torch.bfloat16), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2VLAConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2VLRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + 1 + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where( + torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLFlashAttention2(Qwen2VLAttention): + """ + Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1 + ) + + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLSdpaAttention(Qwen2VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = bool(causal_mask is None and q_len > 1) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_VL_ATTENTION_CLASSES = { + "eager": Qwen2VLAttention, + "flash_attention_2": Qwen2VLFlashAttention2, + "sdpa": Qwen2VLSdpaAttention, +} + + +class Qwen2VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2VLAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +QWEN2VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + QWEN2VL_START_DOCSTRING, +) +class Qwen2VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2VLAConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock", "policy_head"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): + config_class = Qwen2VLVisionConfig + _no_split_modules = ["Qwen2VLVisionBlock"] + + def __init__(self, config) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = PatchMerger( + dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size + ) + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def rot_pos_emb(self, grid_spatiotemporal): + pos_ids = [] + for t, h, w in grid_spatiotemporal: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_spatiotemporal[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward(self, hidden_states: torch.Tensor, grid_spatiotemporal: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal) + + cu_seqlens = torch.repeat_interleave( + grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = func.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + return self.merger(hidden_states) + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + QWEN2VL_START_DOCSTRING, +) +class Qwen2VLModel(Qwen2VLPreTrainedModel): + def __init__(self, config: Qwen2VLAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + and AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ) + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2VLAConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None and ( + not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length + ): + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +QWEN2_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~policy_heads.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing videos. + image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2VisionTransformerPretrainedModel._from_config( + config.vision_config, attn_implementation=config._attn_implementation + ) + self.model = Qwen2VLModel(config) + self.vocab_size = config.vocab_size + self.with_llm_head = config.with_llm_head + + self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides + self.using_film = config.using_film + + self.llm_loss_weight = config.llm_loss_weight + + if isinstance(config.policy_head_config, dict): + config.policy_head_config = AutoConfig.for_model(**config.policy_head_config) + self.policy_head = AutoModel.from_config(config=config.policy_head_config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + if config.policy_head_config.model_type == "scale_dp_policy": + self.policy_head.init_weights() + self.input_action_proj = ActionProjector(config.hidden_size, config.hidden_size) + + if self.using_film: + # Initialize projection layers and condition modulation layers + self.reasoning_action_proj = ActionProjector(config.hidden_size, config.hidden_size) + self.reasoning_film = FiLM(feature_dim=config.hidden_size, condition_dim=config.hidden_size) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: torch.LongTensor, + image_grid_spatiotemporal: Optional[torch.LongTensor] = None, + video_grid_spatiotemporal: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embeddin for text part. + Examples: + Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [3, 4, 5, 6, 7] + text height position_ids: [3, 4, 5, 6, 7] + text width position_ids: [3, 4, 5, 6, 7] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if image_grid_spatiotemporal is not None or video_grid_spatiotemporal is not None: + total_input_ids = input_ids + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_spatiotemporal[image_index][0], + image_grid_spatiotemporal[image_index][1], + image_grid_spatiotemporal[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_spatiotemporal[video_index][0], + video_grid_spatiotemporal[video_index][1], + video_grid_spatiotemporal[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = ( + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + ) + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + ) + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + if getattr(outputs, "rope_deltas", None) is not None: + model_kwargs["rope_deltas"] = outputs.rope_deltas + + return model_kwargs + + @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_spatiotemporal: Optional[torch.LongTensor] = None, + video_grid_spatiotemporal: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + actions: Optional[torch.LongTensor] = None, + states: Optional[torch.FloatTensor] = None, + is_pad: bool = False, + is_eval: bool = False, + tinyvla: bool = False, + ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + self.computed_type = torch.bfloat16 + input_ids = input_ids.to("cuda") + attention_mask = attention_mask.to("cuda") + if not is_eval: + labels = labels.to("cuda") + actions = actions.to(dtype=self.computed_type, device="cuda") + states = states.to(dtype=self.computed_type, device="cuda") + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask + ) + if pixel_values is not None: + pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda") + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_spatiotemporal=image_grid_spatiotemporal) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_spatiotemporal=video_grid_spatiotemporal) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if tinyvla: # dex-vla supports tinyvla-style VLA + return hidden_states + + if self.with_llm_head: + logits = self.lm_head(hidden_states) + logits = logits.float() + else: + logits = None + self.llm_head = None + + llm_loss = None + # cross-entropy loss for VLM + if labels is not None and self.with_llm_head: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + llm_loss = loss_fct(shift_logits, shift_labels) + + # for evaluation + if is_eval: + loss = None + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + if self.using_film: + action_hidden_states = self.film_forward( + labels=labels, input_ids=input_ids, hidden_states=hidden_states + ) + else: # tinyvla + action_hidden_states = hidden_states + + ret = self.policy_head( + actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad + ) + + if self.with_llm_head: + loss = { + "loss": ret["loss"] + self.llm_loss_weight * llm_loss, + "llm_loss": llm_loss, + "action_loss": ret["loss"], + } + else: + loss = { + "loss": ret["loss"], + "llm_loss": (torch.ones(1) * (-100)).to(ret["loss"].dtype).squeeze(0), + "action_loss": ret["loss"], + } + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + torch.cuda.empty_cache() + gc.collect() + del input_ids + del attention_mask + del position_ids + del past_key_values + del inputs_embeds + del labels + del pixel_values + del image_grid_spatiotemporal + del actions + del states + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + def film_forward(self, labels, input_ids, hidden_states): + """ + Perform the forward pass for the film module. + """ + inputs_index = labels[:, :] == -100 + inputs_index = inputs_index.int() + + xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:]) + indexes = torch.argmax((xor_array != 0).float(), dim=1) + input_embeddings = [] + reasoning_embeddings = [] + identity = [] + for i in range(indexes.shape[0]): + end = indexes[i] + 1 + temp = input_ids[i] == 151643 # pad token id for qwen2_vl + start = sum(temp.int()) + input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :])) + identity.append(torch.mean(hidden_states[i, start:end, :], dim=0)) + + reasoning_embeddings.append(self.reasoning_action_proj(hidden_states[i, end:, :])) + input_embeddings = torch.cat(input_embeddings, dim=0) + reasoning_embeddings = torch.cat(reasoning_embeddings, dim=0) + identity = torch.stack(identity) + + action_hidden_states = self.reasoning_film(input_embeddings, reasoning_embeddings).unsqueeze(1) + + action_hidden_states = action_hidden_states + identity.unsqueeze(1) + return action_hidden_states + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_spatiotemporal=None, + video_grid_spatiotemporal=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + rope_deltas = kwargs.get("rope_deltas") + if attention_mask is not None and position_ids is None: + if cache_position is None or (cache_position is not None and cache_position[0] == 0): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask + ) + else: + batch_size, seq_length = input_ids.shape + delta = ( + cache_position[0] + rope_deltas + if cache_position is not None and rope_deltas is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_spatiotemporal": image_grid_spatiotemporal, + "video_grid_spatiotemporal": video_grid_spatiotemporal, + "rope_deltas": rope_deltas, + } + ) + model_inputs.update(kwargs) + return model_inputs diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py new file mode 100644 index 00000000..7af0aa05 --- /dev/null +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python + +# Copyright 2025 DexVLA Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from PIL import Image +from qwen_vl_utils import fetch_image + + +class Qwen2VLAProcess: + def __init__( + self, + tokenizer=None, + max_seq_len=512, + multimodal_processor=None, + ): + super().__init__() + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + self.multimodal_processor = multimodal_processor + + def qwen2_image_preprocess(self, each): + ele = {} + each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)) + ele["image"] = each + + ele["resized_height"] = each.height + ele["resized_width"] = each.width + each = fetch_image(ele) + return torch.from_numpy(np.array(each)) + + def single_forward_process(self, images, raw_lang, reasoning, eval=False, use_reasoning=True): + len_views = images.shape[0] + messages = self.construct_chat_data(len_views, raw_lang) + + data_dict = {"messages": messages} + + image_data = torch.chunk(images, len_views, 0) + + images_list = [] + + for _i, each in enumerate(image_data): + img_pil = self.qwen2_image_preprocess(each) + images_list.append(img_pil) + + text = self.multimodal_processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + model_inputs = self.multimodal_processor( + text=text, + images=images_list, + videos=None, + padding=True, + return_tensors="pt", + ) + + if eval: + new_dict = {} + for k, v in model_inputs.items(): + if "image_grid" in k: + new_dict["image_grid_spatiotemporal"] = v + else: + new_dict[k] = v + return new_dict + + input_labels = torch.ones_like(model_inputs["input_ids"]) * -100 + answer = reasoning + " Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>" + + output_text = self.tokenizer(answer, padding=True, return_tensors="pt") + output_labels = output_text["input_ids"] + model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1) + model_inputs["attention_mask"] = torch.cat( + (model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1 + ) + labels = torch.cat((input_labels, output_labels), dim=-1) + + data_dict["labels"] = labels + for k, v in model_inputs.items(): + if "image_grid" in k: + data_dict["image_grid_spatiotemporal"] = v + else: + data_dict[k] = v + return data_dict + + def forward(self, batch, use_reasoning=True): + """This is the main process function for processing vl data into Qwen2_vl format""" + all_images = batch["images"] + all_images = torch.einsum( + "v b c h w -> b v c h w", all_images + ) # camera_views, batch_size, channel, height, width + + ret_l = [] + + for idx, images in enumerate(all_images): + raw_lang = batch["raw_langs"][idx] + reasoning = batch["reasonings"][idx] + ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning) + ret_l.append(ret_dict) + + return self.post_process(ret_l) + + def post_process(self, instances): + input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances] + labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances] + + image_grid_spatiotemporal = torch.stack( + [instances["image_grid_spatiotemporal"] for instances in instances] + ) + pixel_values = torch.stack([instances["pixel_values"] for instances in instances]) + pixel_values_videos = None + video_grid_spatiotemporal = None + + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) + labels = torch.flip(labels, dims=[1]) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + input_ids = torch.flip(input_ids, dims=[1]) + b = input_ids.shape[0] + + image_grid_spatiotemporal = image_grid_spatiotemporal.reshape( + b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2] + ) + pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2]) + + attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),) + + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask[0], + "labels": labels, + "image_grid_spatiotemporal": image_grid_spatiotemporal, + "pixel_values_videos": pixel_values_videos, + "video_grid_spatiotemporal": video_grid_spatiotemporal, + "pixel_values": pixel_values, + } + + return batch + + def construct_chat_data(self, len_image, raw_lang): + messages = [ + { + "role": "user", + "content": [], + }, + ] + + for _i in range(len_image): + messages[0]["content"].append( + { + "type": "image", + "image": None, + } + ) + messages[0]["content"].append({"type": "text", "text": ""}) + messages[0]["content"][-1]["text"] = raw_lang + + return messages diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8def95a3..4d9b24e8 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -23,6 +23,7 @@ from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig @@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy return PI0Policy + elif name == "dexvla": + from lerobot.common.policies.dexvla.modeling_dexvla import DexVLAPolicy + + return DexVLAPolicy elif name == "pi0fast": from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy @@ -74,6 +79,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return VQBeTConfig(**kwargs) elif policy_type == "pi0": return PI0Config(**kwargs) + elif policy_type == "dexvla": + return DexVLAConfig(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) else: diff --git a/pyproject.toml b/pyproject.toml index 4b858634..361e0857 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] pi0 = ["transformers>=4.48.0"] +dexvla = ["transformers>=4.45.2", "qwen_vl_utils==0.0.10", "timm==0.9.10"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",