From 1aa4f0c08680c011e38fc71dcbd65269170b5a60 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Thu, 20 Feb 2025 17:29:21 +0800 Subject: [PATCH 01/36] add dexvla --- .../policies/dexvla/configuration_dexvla.py | 270 +++ .../common/policies/dexvla/fusion_modules.py | 40 + .../common/policies/dexvla/modeling_dexvla.py | 2070 +++++++++++++++++ .../policies/dexvla/robot_data_processor.py | 151 ++ 4 files changed, 2531 insertions(+) create mode 100644 lerobot/common/policies/dexvla/configuration_dexvla.py create mode 100644 lerobot/common/policies/dexvla/fusion_modules.py create mode 100644 lerobot/common/policies/dexvla/modeling_dexvla.py create mode 100644 lerobot/common/policies/dexvla/robot_data_processor.py diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py new file mode 100644 index 00000000..d634bfa6 --- /dev/null +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2VL model configuration""" + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging +from transformers import AutoModel, AutoConfig + +logger = logging.get_logger(__name__) + + +class Qwen2VLAVisionConfig(PretrainedConfig): + model_type = "dex_vla" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if config_dict.get("model_type") == "qwen2_vl": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DexVLAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2VLModel`]. 
It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig + + >>> # Initializing a Qwen2VL style configuration + >>> configuration = Qwen2VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_vla" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + # For loading policy head + policy_head_type='dit_diffusion_policy', # dit_diffusion_policy + policy_head_size='DiT_L', + action_dim=10, + state_dim=7, + non_lora_lr=1e-4, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2VLAVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2VLAVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for loading policy head + self.policy_head_type = policy_head_type + # if policy_head_type == 'dit_diffusion_policy': + # # self.policy_head_size = kwargs.get("policy_head_size", "none") + # self.policy_head_size = policy_head_size + # # self.policy_head_config = register_configuration_class(self.policy_head_type, model_size=policy_head_size) + # self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type, + # model_size=self.policy_head_size, + # global_cond_dim=hidden_size, action_dim=action_dim, + # state_dim=state_dim) + # elif policy_head_type == 'unet_diffusion_policy': + # self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type, + # global_cond_dim=hidden_size, action_dim=action_dim, + # state_dim=state_dim) + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. 
to have scaled RoPE
+        # TODO: @raushan update config in the hub
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            if self.rope_scaling["type"] == "mrope":
+                self.rope_scaling["type"] = "default"
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self, ignore_keys={"mrope_section"})
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+from transformers import AutoConfig
+
+# The registered key must match `DexVLAConfig.model_type` ("qwen2_vla"),
+# otherwise `AutoConfig.register` raises a ValueError at import time.
+AutoConfig.register("qwen2_vla", DexVLAConfig)
diff --git a/lerobot/common/policies/dexvla/fusion_modules.py b/lerobot/common/policies/dexvla/fusion_modules.py
new file mode 100644
index 00000000..7eb452e0
--- /dev/null
+++ b/lerobot/common/policies/dexvla/fusion_modules.py
@@ -0,0 +1,40 @@
+import torch.nn as nn
+
+class ActionProjector(nn.Module):
+    def __init__(self, in_dim, out_dim=1024):
+        super(ActionProjector, self).__init__()
+        self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
+        self.mlps = nn.ModuleList([
+            # nn.LayerNorm(in_dim),
+            nn.Linear(in_dim, in_dim),
+            nn.GELU(),
+            nn.Linear(in_dim, out_dim),
+            nn.Dropout(0.0),
+        ]
+        )
+
+    def forward(self, x):
+        x = self.global_1d_pool(x.permute(1, 0)).permute(1, 0)
+        for mlp in self.mlps:
+            x = mlp(x)
+        return x
+
+
+class FiLM(nn.Module):
+    def __init__(self, feature_dim, condition_dim):
+        super(FiLM, self).__init__()
+        self.scale_fc = nn.Linear(condition_dim, feature_dim)
+        self.shift_fc = nn.Linear(condition_dim, feature_dim)
+
+        nn.init.zeros_(self.scale_fc.weight)
+        nn.init.zeros_(self.scale_fc.bias)
+        nn.init.zeros_(self.shift_fc.weight)
+        nn.init.zeros_(self.shift_fc.bias)
+
+    def forward(self, x, condition):
+        # Compute the scale and shift parameters from the conditioning vector.
+        scale = self.scale_fc(condition)
+        shift = self.shift_fc(condition)
+
+        # Apply FiLM modulation.
+        return x * (1 + scale) + shift
diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py
new file mode 100644
index 00000000..4e6a7d11
--- /dev/null
+++ b/lerobot/common/policies/dexvla/modeling_dexvla.py
@@ -0,0 +1,2070 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
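For orientation, here is a minimal sketch of how the configuration class and the fusion modules added above can be exercised in isolation. The import paths follow the file locations in this patch; the 64-token sequence length and the 512/1024 feature sizes are illustrative assumptions rather than values taken from the DexVLA training code.

```python
import torch

from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
from lerobot.common.policies.dexvla.fusion_modules import ActionProjector, FiLM

# The defaults mirror the upstream Qwen2-VL hyper-parameters; `policy_head_type`
# is the DexVLA-specific field.
config = DexVLAConfig(policy_head_type="dit_diffusion_policy")
print(config.model_type, config.policy_head_type, config.vision_config.embed_dim)

# ActionProjector pools a (seq_len, in_dim) token sequence into a single
# (1, out_dim) conditioning vector; the dimensions below are arbitrary toy values.
tokens = torch.randn(64, 512)
projector = ActionProjector(in_dim=512, out_dim=1024)
cond = projector(tokens)  # -> shape (1, 1024)

# FiLM starts out as the identity because its scale/shift layers are zero-initialized.
film = FiLM(feature_dim=1024, condition_dim=1024)
features = torch.randn(1, 1024)
modulated = film(features, cond)  # equals `features` until the FiLM layers are trained
```

Zero-initializing the FiLM projections makes the conditioning pathway a no-op at the start of training, so it only gradually learns to modulate the features.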
+"""PyTorch Qwen2-VL model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss, LayerNorm + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + ModelOutput, +) +from fusion_modules import ActionProjector, FiLM +from types import SimpleNamespace + +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_dexvla import DexVLAConfig, Qwen2VLAVisionConfig +import gc + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + + from transformers.modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + +from transformers import AutoConfig, AutoModel + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2VLConfig" + + +@dataclass +class Qwen2VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2VLRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[DexVLAConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids + # So we expand the inv_freq to shape (3, ...) 
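+        # Shape reference: `inv_freq_expanded` is (3, bs, head_dim // 2, 1) and `position_ids_expanded`
+        # is (3, bs, 1, positions); their matmul gives freqs of shape (3, bs, head_dim // 2, positions),
+        # transposed below to (3, bs, positions, head_dim // 2) before cos/sin are built.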
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
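+
+    Example:
+        For the released Qwen2-VL checkpoints (values quoted here only for illustration), `mrope_section`
+        is `[16, 24, 24]` and `head_dim` is 128. After the doubling below it becomes
+        `[16, 24, 24, 16, 24, 24]`, so rotary channels 0-15 and 64-79 follow the temporal position id,
+        channels 16-39 and 80-103 the height position id, and channels 40-63 and 104-127 the width
+        position id. Text tokens are unaffected because their three position ids are identical.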
+ """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = LayerNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class VisionMlp(nn.Module): + def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.act = ACT2FN[hidden_act] + self.fc2 = nn.Linear(hidden_dim, dim) + + def forward(self, x) -> torch.Tensor: + return self.fc2(self.act(self.fc1(x))) + + +class VisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = 
apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_VL_VISION_ATTENTION_CLASSES = { + "eager": VisionAttention, + "flash_attention_2": VisionFlashAttention2, + "sdpa": VisionSdpaAttention, +} + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) + self.norm2 = LayerNorm(config.embed_dim, eps=1e-6) + mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) + + self.attn = 
QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.embed_dim, num_heads=config.num_heads + ) + self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states).to(torch.bfloat16), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: DexVLAConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2VLRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + 1 + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLFlashAttention2(Qwen2VLAttention): + """ + Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLSdpaAttention(Qwen2VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_VL_ATTENTION_CLASSES = { + "eager": Qwen2VLAttention, + "flash_attention_2": Qwen2VLFlashAttention2, + "sdpa": Qwen2VLSdpaAttention, +} + + +class Qwen2VLDecoderLayer(nn.Module): + def __init__(self, config: DexVLAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +QWEN2VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + QWEN2VL_START_DOCSTRING, +) +class Qwen2VLPreTrainedModel(PreTrainedModel): + config_class = DexVLAConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock", "policy_head"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): + config_class = Qwen2VLAVisionConfig + _no_split_modules = ["Qwen2VLVisionBlock"] + + def __init__(self, config) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = PatchMerger( + dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size + ) + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + return self.merger(hidden_states) + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + 
QWEN2VL_START_DOCSTRING, +) +class Qwen2VLModel(Qwen2VLPreTrainedModel): + def __init__(self, config: DexVLAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. 
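+        # Note (added for clarity): `position_ids` is expanded to shape (3, batch_size, seq_len), one index per
+        # token for each of the temporal/height/width rotary axes used by multimodal RoPE. For text-only inputs
+        # the three axes carry identical values (e.g. a 1D cache_position [0, 1, 2] simply becomes three stacked
+        # copies of [0, 1, 2]); vision tokens receive distinct t/h/w indices from `get_rope_index`.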
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
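+        # Note (added for clarity): the logic below builds an additive 4D mask of shape
+        # (batch_size, 1, query_length, kv_length) filled with the dtype minimum at disallowed positions,
+        # unless SDPA can rely on its `is_causal` fast path, in which case `None` is returned instead.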
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: DexVLAConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + #print('@'*50) + #print(attention_mask.shape) + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +QWEN2_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
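+
+    A minimal sketch of how these inputs are usually produced (the base Qwen2-VL processor is shown; the literal
+    vision placeholder tokens in the prompt would normally be generated by `apply_chat_template`, as in the
+    generation example later in this file):
+
+    ```python
+    >>> import requests
+    >>> from PIL import Image
+    >>> from transformers import AutoProcessor
+
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+    >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
+    >>> text = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
+    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+    >>> # `inputs` now contains `input_ids`, `attention_mask`, `pixel_values` and `image_grid_thw`,
+    >>> # which map directly onto the arguments documented above.
+    ```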
+"""
+
+class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
+            config.vision_config, attn_implementation=config._attn_implementation
+        )
+        self.model = Qwen2VLModel(config)
+        self.vocab_size = config.vocab_size
+
+        self.padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
+        self.with_llm_head = config.with_llm_head
+        # self.with_external_vit = config.with_external_vit
+        self.with_text_fcs = config.with_text_fcs
+        self.only_using_input_embeddings = config.only_using_input_embeddings
+        self.using_film = config.using_film
+        self.using_xattn = config.using_xattn
+
+        self.llm_loss_weight = config.llm_loss_weight
+
+        if isinstance(config.policy_head_config, dict):
+            config.policy_head_config = AutoConfig.for_model(**config.policy_head_config)
+        self.policy_head = AutoModel.from_config(config=config.policy_head_config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        if config.policy_head_config.model_type == "dit_diffusion_policy":
+            self.policy_head.init_weights()
+        self.input_action_proj = ActionProjector(config.hidden_size, config.hidden_size)
+
+        if self.using_film:
+            self.reasoning_action_proj = ActionProjector(config.hidden_size, config.hidden_size)
+            self.reasoning_film = FiLM(feature_dim=config.hidden_size, condition_dim=config.hidden_size)
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def get_rope_index(
+        self,
+        input_ids: torch.LongTensor,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+        Explanation:
+            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+            For pure text embedding sequence, the rotary position embedding is no different from modern LLMs.
+            Examples:
+                input_ids: [T T T T T], here T is for text.
+                temporal position_ids: [0, 1, 2, 3, 4]
+                height position_ids: [0, 1, 2, 3, 4]
+                width position_ids: [0, 1, 2, 3, 4]
+
+            For vision and text embedding sequence, we calculate 3D rotary position embedding for the vision part
+            and 1D rotary position embedding for the text part.
+            Examples:
+                Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
+                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+                vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
+                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+                text temporal position_ids: [3, 4, 5, 6, 7]
+                text height position_ids: [3, 4, 5, 6, 7]
+                text width position_ids: [3, 4, 5, 6, 7]
+                Here we calculate the text start position_ids as the max vision position_ids plus 1.
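+
+            The vision ids of the example above can be reproduced with the same expand/flatten pattern used in
+            the implementation below (illustration only):
+
+            ```python
+            >>> import torch
+            >>> t, h, w = 3, 2, 2  # grid of the worked example
+            >>> t_idx = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
+            >>> h_idx = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
+            >>> w_idx = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
+            >>> vision_pos = torch.stack([t_idx, h_idx, w_idx])    # the three vision rows listed above
+            >>> text_pos = vision_pos.max() + 1 + torch.arange(5)  # 5 text tokens -> [3, 4, 5, 6, 7]
+            ```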
+ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if image_grid_thw is not None or video_grid_thw is not None: + total_input_ids = input_ids + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 
else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + if getattr(outputs, "rope_deltas", None) is not None: + model_kwargs["rope_deltas"] = outputs.rope_deltas + + return model_kwargs + + @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + actions: Optional[torch.LongTensor] = None, + states: Optional[torch.FloatTensor] = None, + is_pad: bool = False, + is_eval: bool = False, + tinyvla: bool = False, + ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
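+            actions (`torch.FloatTensor` of shape `(batch_size, chunk_size, action_dim)`, *optional*):
+                Target action chunk supervised by the diffusion policy head during training.
+            states (`torch.FloatTensor` of shape `(batch_size, state_dim)`, *optional*):
+                Proprioceptive robot state passed to the policy head as an extra condition.
+            is_pad (`torch.BoolTensor` of shape `(batch_size, chunk_size)`, *optional*):
+                Marks padded steps of the action chunk; padded steps are excluded from the action loss.
+            is_eval (`bool`, *optional*, defaults to `False`):
+                If `True`, only the language-model outputs are returned and no action loss is computed.
+            tinyvla (`bool`, *optional*, defaults to `False`):
+                If `True` together with `is_eval`, the backbone hidden states are returned directly (TinyVLA-style).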
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + self.computed_type = torch.bfloat16 + input_ids = input_ids.to("cuda") + attention_mask = attention_mask.to("cuda") + if not is_eval: + labels = labels.to("cuda") + actions = actions.to(dtype=self.computed_type, device='cuda') + states = states.to(dtype=self.computed_type, device='cuda') + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + if pixel_values is not None: + pixel_values = pixel_values.to(dtype=self.computed_type, device='cuda') + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, 
video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if tinyvla and is_eval: # dex-vla supports tinyvla-style VLA + return hidden_states + + logits = self.lm_head(hidden_states) + logits = logits.float() + + llm_loss = None + + # cross-entropy loss for VLM + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + llm_loss = loss_fct(shift_logits, shift_labels) + + # for evaluation + if is_eval: + loss = None + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + if self.using_film: + action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids, + hidden_states=hidden_states) + else: # tinyvla + action_hidden_states = hidden_states + + ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad) + + loss = {'loss': ret['loss'] + self.llm_loss_weight * llm_loss, + 'llm_loss': llm_loss, + 'action_loss': ret['loss']} + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + torch.cuda.empty_cache() + gc.collect() + del input_ids + del attention_mask + del position_ids + del past_key_values + del inputs_embeds + del labels + del pixel_values + del image_grid_thw + del actions + del states + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + def film_forward(self, labels, input_ids, hidden_states): + """ + Perform the forward pass for the film module. 
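+
+        The `-100` entries of `labels` (the masked prompt tokens) are used to locate the boundary between the
+        prompt and the generated reasoning tokens. The prompt hidden states are projected by
+        `input_action_proj`, the reasoning hidden states by `reasoning_action_proj`, and the two are fused with
+        FiLM conditioning (`reasoning_film`); a mean-pooled copy of the prompt hidden states is added back as a
+        residual. The result conditions the policy head on both the instruction and the model's own reasoning.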
+ """ + inputs_index = labels[:, :] == -100 + inputs_index = inputs_index.int() + + xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:]) + indexs = torch.argmax((xor_array != 0).float(), dim=1) + input_embeddings = [] + reasoning_embeddings = [] + identity = [] + for i in range(indexs.shape[0]): + end = indexs[i] + 1 + temp = input_ids[i] == 151643 # pad token id for qwen2_vl + start = sum(temp.int()) + input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :])) + identity.append(torch.mean(hidden_states[i, start:end, :], dim=0)) + + reasoning_embeddings.append(self.reasoning_action_proj(hidden_states[i, end:, :])) + input_embeddings = torch.cat(input_embeddings, dim=0) + reasoning_embeddings = torch.cat(reasoning_embeddings, dim=0) + identity = torch.stack(identity) + + action_hidden_states = self.reasoning_film(input_embeddings, reasoning_embeddings).unsqueeze(1) + + action_hidden_states = action_hidden_states + identity.unsqueeze(1) + return action_hidden_states + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0]:] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + rope_deltas = kwargs.get("rope_deltas", None) + if attention_mask is not None and position_ids is None: + if cache_position is None or (cache_position is not None and cache_position[0] == 0): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + else: + batch_size, seq_length = input_ids.shape + delta = ( + cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + ) + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + 
cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "rope_deltas": rope_deltas, + } + ) + model_inputs.update(kwargs) + return model_inputs + + def evaluate(self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + tokenizer=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): + input_ids = input_ids.to('cuda') + with torch.inference_mode(): + outputs = self.generate( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + is_eval=is_eval, + num_beams=1, + do_sample=False, + temperature=0.2, + max_new_tokens=256, + eos_token_id=tokenizer.eos_token_id, # End of sequence token + pad_token_id=tokenizer.eos_token_id, # Pad token + use_cache=True, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + output_ids = outputs.sequences + input_token_len = input_ids.shape[1] + n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0] + + outputs_text = outputs_text.strip() + last_hidden_states = [each[-1] for each in outputs.hidden_states] # all hidden states + all_hidden_states = torch.cat(last_hidden_states, dim=1) + + action_hidden_states = None + + if self.using_film: + action_hidden_states = self.film_forward(labels=torch.ones_like(output_ids), + input_ids=output_ids, + hidden_states=torch.cat(last_hidden_states, dim=1)) + + action = self.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad) + return action, outputs_text + + def evaluate_tinyvla(self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): + input_ids = input_ids.to('cuda') + with torch.inference_mode(): + all_hidden_states = self.forward(input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + is_eval=is_eval, + tinyvla=True) + + all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1) + + action = self.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad) + return action, "tinyvla no output" + + +from transformers import AutoModelForCausalLM +AutoModelForCausalLM.register(DexVLAConfig, DexVLAPolicy) diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py new file mode 100644 index 00000000..1f90dc79 --- /dev/null +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -0,0 +1,151 @@ +from PIL import Image +import numpy as np +from torchvision.transforms.functional import to_pil_image, to_tensor +import torchvision.transforms as transforms +import torch +from qwen_vl_utils import process_vision_info +from qwen_vl_utils import fetch_image +class Qwen2VLAProcess: + def __init__( + self, + language=None, + tokenizer=None, + max_seq_len=512, + 
multimodal_processor=None, + camera_names=None, + data_args=None, + ): + super().__init__() + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + self.camera_names = camera_names + # self.language = language + self.multimodal_processor = multimodal_processor + self.data_args = data_args + + def preprocess_image(self, image, size=224): + # Model has been trained to handle images of different aspects ratios + # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize + # options are helpful to improve quality in some tasks. + image = np.asarray(image) + if image.ndim == 2: # Convert image without last channel into greyscale. + image = np.stack((image,) * 3, axis=-1) + image = image[..., :3] # Remove alpha layer. + assert image.shape[-1] == 3 + + image_pil = to_pil_image(image) + + # Step 2: Define the resize transformation + resize_transform = transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR) + + # Step 3: Apply the resize transformation + image_resized_pil = resize_transform(image_pil) + + # Step 4: Convert back to tensor if needed + image_resized = to_tensor(image_resized_pil) + return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1] + + def qwen2_image_preprocess(self, each, camera_name): + ele = {} + each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)) + ele['image'] = each + if 'wrist' in camera_name: + w, h = eval(self.data_args.image_size_wrist) + ele['resized_height'] = h + ele['resized_width'] = w + else: + ele['resized_height'] = each.height + ele['resized_width'] = each.width + each = fetch_image(ele) + return torch.from_numpy(np.array(each)) + + def forward_process(self, sample, use_reasoning=True): + if sample['image'].ndim == 5 and sample['image'].shape[1] > 2: + video = True + else: + video = False + messages = self.datastruct_droid2llava(sample, video=video) + + data_dict = dict( + messages=messages, + images=None + ) + + image_data = torch.chunk(sample['image'], sample['image'].shape[0], 0) + + images_list = [] + + for i, each in enumerate(image_data): + if each.ndim == 4: + img_pil = self.qwen2_image_preprocess(each, self.camera_names[i]) + else: + img_pil = [] + for temp in each.squeeze(0): + img_pil.append(self.qwen2_image_preprocess(temp, self.camera_names[i])) + img_pil = torch.stack(img_pil, 0) + images_list.append(img_pil) + # TODO RESIZE + # image_data = image_data / 255.0 + if video: + image_data = None + video_inputs = images_list + else: + image_data = images_list + video_inputs = None + + text = self.multimodal_processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + model_inputs = self.multimodal_processor( + text=text, + images=image_data, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + input_labels = torch.ones_like(model_inputs['input_ids']) * -100 + if use_reasoning: + answer = sample['reasoning'] + "Next action:" + '<|im_end|>' + else: + answer = '' + '<|im_end|>' + + output_text = self.tokenizer(answer, padding=True, return_tensors="pt") + output_labels = output_text['input_ids'] + model_inputs['input_ids'] = torch.cat((model_inputs['input_ids'], output_text['input_ids']), dim=-1) + model_inputs['attention_mask'] = torch.cat((model_inputs['attention_mask'], output_text['attention_mask']), dim=-1) + labels = torch.cat((input_labels, output_labels), dim=-1) + data_dict['state'] = sample['state'] + data_dict['action'] = sample['action'] + data_dict['is_pad'] = sample['is_pad'] + data_dict['labels'] = labels + for k, v 
in model_inputs.items(): + data_dict[k] = v + return data_dict + + def datastruct_droid2llava(self, sample, video=False): + len_image = sample['image'].shape[0] + + messages = [ + { + "role": "user", + "content": [], + }, + # {"role": "assistant", "content": f''}, + ] + + for i in range(len_image): + if video: + messages[0]['content'].append({ + "type": "video", + "video": None, + }) + else: + messages[0]['content'].append({ + "type": "image", + "image": None, + }) + messages[0]['content'].append({"type": "text", "text": f""}) + messages[0]['content'][-1]['text'] = sample['raw_lang'] + + return messages \ No newline at end of file From 20f346956acd8a4adf90f95e5839cc30fe054525 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Thu, 20 Feb 2025 17:29:29 +0800 Subject: [PATCH 02/36] add policy head --- .../policy_head/configuration_scaledp.py | 95 +++ .../dexvla/policy_head/modeling_scaledp.py | 552 ++++++++++++++++++ 2 files changed, 647 insertions(+) create mode 100644 lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py create mode 100644 lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py diff --git a/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py new file mode 100644 index 00000000..cb771847 --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py @@ -0,0 +1,95 @@ +import os +from typing import Union, List +from transformers import PretrainedConfig + +from transformers.utils import logging +from transformers import AutoConfig, AutoModelForCausalLM +logger = logging.get_logger(__name__) + +MODEL_STRUCTURE = { + 'ScaleDP_H': {'depth': 32, 'n_emb': 1280, 'num_heads': 16, }, + 'ScaleDP_L': {'depth': 24, 'n_emb': 1024, 'num_heads': 16, }, # 400M +} + +class ScaleDPPolicyConfig(PretrainedConfig): + ''' + Configuration for ScaleDP policy head + ''' + model_type = "scale_dp_policy" + def __init__( + self, + eval: bool = False, + action_dim: int = 14, # action dim + # output_dim: int = 14, # action dim + cond_dim: int = 1536, # the input dim of the condition + state_dim: int = 14, # the input dim of the state + prediction_horizon: int = 16, # horizon + n_obs_steps: int = 2, # number of observation steps + depth: int = 28, # number of DiT blocks + n_emb: int = 256, # embedding size + num_heads: int = 16, + mlp_ratio: int = 4.0, + time_as_cond: bool = True, + obs_as_cond: bool = True, + learn_sigma: bool = False, + model_size: str = "none", + num_inference_timesteps: int = 10, + num_queries: int = 16, + noise_samples: int = 1, + num_train_timesteps: int = 100, + is_tinyvla: bool = False, + **kwargs + ): + if model_size != "none": + depth = MODEL_STRUCTURE[model_size]['depth'] + n_emb = MODEL_STRUCTURE[model_size]['n_emb'] + num_heads = MODEL_STRUCTURE[model_size]['num_heads'] + else: + # raise ValueError("model_size show not be 'none'") + pass + # print("model_size should not be 'none'") + self.eval = eval + + self.input_dim = action_dim + self.output_dim = action_dim + self.prediction_horizon = prediction_horizon + + self.is_tinyvla = is_tinyvla + + self.cond_dim = cond_dim + self.state_dim = state_dim + + self.n_obs_steps = n_obs_steps + self.depth = depth + self.n_emb = n_emb + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.time_as_cond = time_as_cond + self.obs_as_cond = obs_as_cond + self.learn_sigma = learn_sigma + + self.num_inference_timesteps = num_inference_timesteps + self.num_queries = prediction_horizon + 
self.noise_samples = noise_samples + self.num_train_timesteps = num_train_timesteps + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "llava_pythia": + config_dict = config_dict["action_head"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + +AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py new file mode 100644 index 00000000..41df2c8f --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py @@ -0,0 +1,552 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +from typing import Tuple + +import timm +import numpy as np +import logging + +import math +from typing import Tuple + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.jit import Final +from timm.models.vision_transformer import Mlp, use_fused_attn +from transformers.modeling_utils import PreTrainedModel +from transformers import AutoModel, AutoModelForCausalLM + +_logger = logging.getLogger(__name__) + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + # attn = q @ k.transpose(-2, -1) + # if attn_mask is not None: + # attn += attn_mask + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = attn @ v + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + + # Add attention mask if provided + if attn_mask is not None: + attn_scores += attn_mask + + # Apply softmax to get attention weights (softmax is applied along the last 
dimension) + attn_weights = F.softmax(attn_scores, dim=-1) + + # Dropout on attention weights (if dropout is used) + attn_weights = self.attn_drop(attn_weights) + + # Apply attention weights to value tensor (V) + x = torch.matmul(attn_weights, v) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +logger = logging.getLogger(__name__) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.bfloat16) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding.to(dtype=torch.bfloat16) + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + + +################################################################################# +# Core ScaleDP Model # +################################################################################# + +class ScaleDPBlock(nn.Module): + """ + A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conScaleDPioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, attn_mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask) # norm, scale&shift, attn, scale, + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of ScaleDP. 
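+
+    Applies a condition-dependent shift and scale (adaLN) to the normalized tokens and projects them to the
+    action dimension, roughly `out = Linear(LayerNorm(x) * (1 + scale(c)) + shift(c))`. Both the modulation and
+    the output projection are zero-initialized in `initialize_weights`, so the untrained head predicts zeros.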
+ """ + + def __init__(self, hidden_size, output_dim): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, output_dim, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +from .configuration_scaledp import ScaleDPPolicyConfig +class ScaleDP(PreTrainedModel): + """ + Diffusion models with a Transformer backbone. + """ + config_class = ScaleDPPolicyConfig + def __init__( + self, + config: ScaleDPPolicyConfig, + ): + super().__init__(config) + # compute number of tokens for main trunk and conScaleDPion encoder + if config.n_obs_steps is None: + config.n_obs_steps = config.prediction_horizon + T = config.prediction_horizon + T_cond = 1 + if not config.time_as_cond: + T += 1 + T_cond -= 1 + obs_as_cond = config.cond_dim > 0 + if obs_as_cond: + assert config.time_as_cond + T_cond += config.n_obs_steps + + self.is_tinyvla = config.is_tinyvla + if config.is_tinyvla: + self.global_1d_pool = nn.AdaptiveAvgPool1d(1) + self.norm_after_pool = nn.LayerNorm(config.cond_dim) + # self.combine = nn.Linear(cond_dim+state_dim, cond_dim) + self.combine = nn.Sequential( + nn.Linear(config.cond_dim+config.state_dim, 1024), + nn.ReLU(), + nn.Linear(1024, 1024), + nn.ReLU(), + nn.Linear(1024, config.cond_dim) + ) + self.learn_sigma = config.learn_sigma + self.input_dim = config.input_dim + self.output_dim = config.output_dim * 2 if config.learn_sigma else config.output_dim + self.num_heads = config.num_heads + + self.x_embedder = nn.Linear(config.input_dim, config.n_emb) + self.t_embedder = TimestepEmbedder(config.n_emb) + self.cond_obs_emb = None + if obs_as_cond: + self.cond_obs_emb = nn.Linear(config.cond_dim, config.n_emb) + + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb)) + + self.blocks = nn.ModuleList([ + ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) for _ in range(config.depth) + ]) + self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim) + # self.initialize_weights() + # constants + self.T = T + self.T_cond = T_cond + self.prediction_horizon = config.prediction_horizon + self.time_as_cond = config.time_as_cond + self.action_dim = config.output_dim + self.obs_as_cond = obs_as_cond + logger.info( + "number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()) + ) + + from diffusers.schedulers.scheduling_ddim import DDIMScheduler + self.num_inference_timesteps = config.num_inference_timesteps + # self.proj_to_action = nn.Identity() + self.noise_scheduler = DDIMScheduler( + num_train_timesteps=config.num_train_timesteps, # 100 + beta_schedule='squaredcos_cap_v2', + clip_sample=True, + set_alpha_to_one=True, + steps_offset=0, + prediction_type='epsilon' + ) + self.num_queries = config.num_queries #16 + self.noise_samples = config.noise_samples # 1 + # self.num_inference_timesteps = config.num_inference_timesteps # 100 + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + nn.init.normal_(self.pos_embed, mean=0.0, std=0.02) + + 
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.cond_obs_emb.weight, mean=0.0, std=0.02) + nn.init.constant_(self.cond_obs_emb.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in ScaleDP blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def get_optim_groups(self, weight_decay: float = 1e-3): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the models into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, Attention) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.startswith("bias"): + # MultiheadAttention bias starts with "bias" + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters %s were not separated into either decay/no_decay set!" % ( + str(param_dict.keys() - union_params), + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, + ] + return optim_groups + + def configure_optimizers(self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95)): + optim_groups = self.get_optim_groups(weight_decay=weight_decay) + optimizer = torch.optim.AdamW( + optim_groups, lr=learning_rate, betas=betas + ) + return optimizer + + def forward(self, actions, hidden_states, states, is_pad): + """ + Forward pass for the diffusion head. 
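+
+        During training the head draws a random diffusion timestep, corrupts the target actions with Gaussian
+        noise through the DDIM scheduler, predicts that noise with `model_forward`, and returns the masked MSE
+        between predicted and true noise; at inference it starts from pure noise and runs
+        `num_inference_timesteps` reverse steps. A standalone sketch of the same epsilon-prediction objective
+        (shapes are illustrative only):
+
+        ```python
+        >>> import torch
+        >>> from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+
+        >>> sched = DDIMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2", prediction_type="epsilon")
+        >>> actions = torch.randn(2, 16, 14)                  # (B, Ta, action_dim)
+        >>> noise = torch.randn_like(actions)
+        >>> t = torch.randint(0, sched.config.num_train_timesteps, (2,))
+        >>> noisy = sched.add_noise(actions, noise, t)        # forward diffusion
+        >>> # a denoiser such as ScaleDP is then trained to recover `noise` from (`noisy`, `t`, conditioning)
+        ```
+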
+ :param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1 + :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [B,Tokens, D] 8 1200 1024 + :param states: robot states, shape [B, D] + :return: loss + """ + if actions is not None: # training time + B = actions.size(0) + actions = actions[:, :self.num_queries] + is_pad = is_pad[:, :self.num_queries] + num_noise_samples = self.noise_samples + # sample noise to add to actions + noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device, + dtype=actions.dtype) # num_noise, B, Ta, D(1, 2, 16, 14) + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, + (B,), device=actions.device + ).long() + + timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) + + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = torch.cat([self.noise_scheduler.add_noise( + actions, noise[i], timesteps) + for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim] + + noisy_actions = noisy_actions.to(dtype=actions.dtype) + assert hidden_states.ndim == 3 + + hidden_states = hidden_states.repeat(num_noise_samples, 1, 1) + timesteps = timesteps.repeat(num_noise_samples) + is_pad = is_pad.repeat(num_noise_samples, 1) + states = states.repeat(num_noise_samples, 1) + + noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states) + noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none') + loss = (loss * ~is_pad.unsqueeze(-1)).mean() + # loss_dict['loss'] = loss + return {'loss': loss} + # return loss + else: # inference time + B = 1 + Tp = self.num_queries + action_dim = self.action_dim + + # initialize action from Guassian noise + noisy_action = torch.randn((B, Tp, action_dim)).cuda() + + naction = noisy_action.to(dtype=hidden_states.dtype) + # init scheduler + self.noise_scheduler.set_timesteps(self.num_inference_timesteps) + + for k in self.noise_scheduler.timesteps: + # predict noise + noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states) + + # inverse diffusion step (remove noise) + naction = self.noise_scheduler.step( + model_output=noise_pred, + timestep=k, + sample=naction + ).prev_sample + + return naction + + def model_forward(self, x, t, global_cond, states): + """ + Forward pass of ScaleDP. 
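+        The global condition is concatenated with the robot state (when provided),
+        projected by `self.combine`, and added to the timestep embedding; that sum
+        conditions every ScaleDP block through adaLN modulation.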
+ x: (N, T, input_dim) noisy actions + t: (N,) tensor of diffusion timesteps + global_cond: (N, n_obs_steps, D) tensor of conScaleDPions: image embeddings + """ + if self.is_tinyvla: + global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1) + global_cond = self.norm_after_pool(global_cond) + else: + global_cond = global_cond.squeeze(1) + global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond + global_cond = self.combine(global_cond) + + if not torch.is_tensor(t): + t = torch.tensor([t], dtype=torch.long, device=x.device) + elif torch.is_tensor(t) and len(t.shape) == 0: + t = t[None].to(x.device) + t = t.expand(t.shape[0]) + + x = self.x_embedder(x) + self.pos_embed.to(device=x.device, dtype=x.dtype) # (N, T, D), where T = prediction_horizon + t = self.t_embedder(t) # (N, D) + if self.obs_as_cond: + global_cond = self.cond_obs_emb(global_cond) # (N, D) + # c = t + global_cond.sum(dim=1) # (N, D) + c = t + global_cond # (N, D) + for block in self.blocks: + # x = block(x, c, attn_mask=self.mask) # (N, T, D) + x = block(x, c, attn_mask=None) # (N, T, D) + x = self.final_layer(x, c) # (N, T, output_dim) + return x + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. 
/ 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# ScaleDP Configs # +################################################################################# + +def ScaleDP_H(**kwargs): + return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs) + +def ScaleDP_L(**kwargs): + return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs) + + + +AutoModel.register(ScaleDPPolicyConfig, ScaleDP) From 5701f02ea8ece18a8aa7952910a6b8351e39ed23 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Thu, 20 Feb 2025 17:40:59 +0800 Subject: [PATCH 03/36] add create config and policy for dexvla --- lerobot/common/policies/__init__.py | 1 + lerobot/common/policies/factory.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index 2e4486ef..5a7d1b8a 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -3,3 +3,4 @@ from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfi from .pi0.configuration_pi0 import PI0Config as PI0Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig +from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index cd440f7a..e7777367 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -26,6 +26,7 @@ from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy return PI0Policy + elif name == "dexvla": + from lerobot.common.policies.dexvla.modeling_dexvla import DexVLAPolicy + + return DexVLAPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -70,6 +75,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return VQBeTConfig(**kwargs) elif policy_type == "pi0": return PI0Config(**kwargs) + elif policy_type == "dexvla": + return DexVLAConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") From 15fdeb382c005a8793f76e96043e0b8e1bdc4be8 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Thu, 20 Feb 2025 18:29:01 +0800 Subject: [PATCH 04/36] remove unused code --- .../policies/dexvla/configuration_dexvla.py | 21 +++------ .../common/policies/dexvla/modeling_dexvla.py | 46 ++----------------- .../policy_head/configuration_scaledp.py | 3 -- .../dexvla/policy_head/modeling_scaledp.py | 10 +--- 4 files changed, 11 insertions(+), 69 deletions(-) diff --git 
a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index d634bfa6..2ba44f44 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -208,7 +208,7 @@ class DexVLAConfig(PretrainedConfig): policy_head_size='DiT_L', action_dim=10, state_dim=7, - non_lora_lr=1e-4, + chunk_size=50, **kwargs, ): if isinstance(vision_config, dict): @@ -228,18 +228,11 @@ class DexVLAConfig(PretrainedConfig): # for loading policy head self.policy_head_type = policy_head_type - # if policy_head_type == 'dit_diffusion_policy': - # # self.policy_head_size = kwargs.get("policy_head_size", "none") - # self.policy_head_size = policy_head_size - # # self.policy_head_config = register_configuration_class(self.policy_head_type, model_size=policy_head_size) - # self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type, - # model_size=self.policy_head_size, - # global_cond_dim=hidden_size, action_dim=action_dim, - # state_dim=state_dim) - # elif policy_head_type == 'unet_diffusion_policy': - # self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type, - # global_cond_dim=hidden_size, action_dim=action_dim, - # state_dim=state_dim) + self.policy_head_config = AutoConfig.for_model(model_type=policy_head_type, + model_size=policy_head_size, + cond_dim=hidden_size, action_dim=action_dim, + prediction_horizon=chunk_size, + state_dim=state_dim) # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -266,5 +259,3 @@ class DexVLAConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -from transformers import AutoConfig -AutoConfig.register("dex_vla", DexVLAConfig) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index 4e6a7d11..9af6853f 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -1455,12 +1455,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): self.vocab_size = config.vocab_size self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides - self.with_llm_head = config.with_llm_head - # self.with_external_vit = config.with_external_vit - self.with_text_fcs = config.with_text_fcs - self.only_using_input_embeddings = config.only_using_input_embeddings - self.using_film = config.using_film - self.using_xattn = config.using_xattn self.llm_loss_weight = config.llm_loss_weight @@ -1472,8 +1466,8 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): # Initialize weights and apply final processing self.post_init() - if config.policy_head_config.model_type == "dit_diffusion_policy": - self.policy_head.init_weights() + + self.policy_head.init_weights() self.input_action_proj = ActionProjector(config.hidden_size, config.hidden_size) if self.using_film: @@ -1688,7 +1682,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): states: Optional[torch.FloatTensor] = None, is_pad: bool = False, is_eval: bool = False, - tinyvla: bool = False, ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: r""" Args: @@ -1802,8 +1795,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): ) hidden_states = outputs[0] - if tinyvla and is_eval: # dex-vla supports tinyvla-style VLA - return hidden_states logits = self.lm_head(hidden_states) logits = logits.float() @@ -1839,11 
+1830,9 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): rope_deltas=rope_deltas, ) - if self.using_film: - action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids, + action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids, hidden_states=hidden_states) - else: # tinyvla - action_hidden_states = hidden_states + ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad) @@ -2041,30 +2030,3 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): action = self.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad) return action, outputs_text - def evaluate_tinyvla(self, - input_ids: torch.LongTensor = None, - actions=None, - states=None, - is_pad=None, - is_eval=True, - pixel_values=None, - attention_mask=None, - image_grid_thw=None, - ): - input_ids = input_ids.to('cuda') - with torch.inference_mode(): - all_hidden_states = self.forward(input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - image_grid_thw=image_grid_thw, - is_eval=is_eval, - tinyvla=True) - - all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1) - - action = self.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad) - return action, "tinyvla no output" - - -from transformers import AutoModelForCausalLM -AutoModelForCausalLM.register(DexVLAConfig, DexVLAPolicy) diff --git a/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py index cb771847..0837f499 100644 --- a/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py @@ -34,10 +34,8 @@ class ScaleDPPolicyConfig(PretrainedConfig): learn_sigma: bool = False, model_size: str = "none", num_inference_timesteps: int = 10, - num_queries: int = 16, noise_samples: int = 1, num_train_timesteps: int = 100, - is_tinyvla: bool = False, **kwargs ): if model_size != "none": @@ -54,7 +52,6 @@ class ScaleDPPolicyConfig(PretrainedConfig): self.output_dim = action_dim self.prediction_horizon = prediction_horizon - self.is_tinyvla = is_tinyvla self.cond_dim = cond_dim self.state_dim = state_dim diff --git a/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py index 41df2c8f..4c78b6e1 100644 --- a/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py @@ -219,10 +219,6 @@ class ScaleDP(PreTrainedModel): assert config.time_as_cond T_cond += config.n_obs_steps - self.is_tinyvla = config.is_tinyvla - if config.is_tinyvla: - self.global_1d_pool = nn.AdaptiveAvgPool1d(1) - self.norm_after_pool = nn.LayerNorm(config.cond_dim) # self.combine = nn.Linear(cond_dim+state_dim, cond_dim) self.combine = nn.Sequential( nn.Linear(config.cond_dim+config.state_dim, 1024), @@ -456,11 +452,7 @@ class ScaleDP(PreTrainedModel): t: (N,) tensor of diffusion timesteps global_cond: (N, n_obs_steps, D) tensor of conScaleDPions: image embeddings """ - if self.is_tinyvla: - global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1) - global_cond = self.norm_after_pool(global_cond) - else: - global_cond = global_cond.squeeze(1) + global_cond = global_cond.squeeze(1) global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond global_cond = 
self.combine(global_cond) From 5a67ce16bf926e0c264afdc5c11724381913e664 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Sun, 23 Feb 2025 15:20:55 +0800 Subject: [PATCH 05/36] update policy_heads --- .../configuration_scaledp.py | 0 .../configuration_unet_diffusion.py | 62 +++ .../modeling_scaledp.py | 0 .../policy_heads/modeling_unet_diffusion.py | 358 ++++++++++++++++++ 4 files changed, 420 insertions(+) rename lerobot/common/policies/dexvla/{policy_head => policy_heads}/configuration_scaledp.py (100%) create mode 100644 lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py rename lerobot/common/policies/dexvla/{policy_head => policy_heads}/modeling_scaledp.py (100%) create mode 100644 lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py diff --git a/lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py similarity index 100% rename from lerobot/common/policies/dexvla/policy_head/configuration_scaledp.py rename to lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py new file mode 100644 index 00000000..38e403a6 --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py @@ -0,0 +1,62 @@ +import os +from typing import Union, List +from transformers import PretrainedConfig + +from transformers.utils import logging +from transformers import AutoConfig, AutoModelForCausalLM +logger = logging.get_logger(__name__) + +class UnetDiffusionPolicyConfig(PretrainedConfig): + ''' + Configuration for dit diffusion policy head + ''' + model_type = "unet_diffusion_policy" + + def __init__( + self, + action_dim=10, + global_cond_dim=2048, + diffusion_step_embed_dim=256, + down_dims=[256, 512, 1024], + kernel_size=5, + n_groups=8, + state_dim=7, + prediction_horizon=16, + noise_samples=1, + num_inference_timesteps=10, + num_train_timesteps=100, + **kwargs + ): + self.input_dim = action_dim + self.noise_samples = noise_samples + self.prediction_horizon = prediction_horizon + self.num_inference_timesteps = num_inference_timesteps + self.global_cond_dim = global_cond_dim + self.diffusion_step_embed_dim = diffusion_step_embed_dim + self.down_dims = down_dims + self.kernel_size = kernel_size + self.n_groups = n_groups + self.state_dim = state_dim + self.num_train_timesteps = num_train_timesteps + + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "llava_pythia": + config_dict = config_dict["action_head"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
+ ) + + return cls.from_dict(config_dict, **kwargs) + +AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py similarity index 100% rename from lerobot/common/policies/dexvla/policy_head/modeling_scaledp.py rename to lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py new file mode 100644 index 00000000..a7b456d2 --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -0,0 +1,358 @@ +""" +Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi +""" +from typing import Callable, Union +import math +from collections import OrderedDict, deque +from packaging.version import parse as parse_version +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +# requires diffusers==0.11.1 +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.training_utils import EMAModel +from .configuration_unet_diffusion import UnetDiffusionPolicyConfig +from transformers.modeling_utils import PreTrainedModel +from transformers import AutoModel, AutoModelForCausalLM +import copy +# =================== UNet for Diffusion ============== + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim, dtype): + super().__init__() + self.dim = dim + self.dtype=dtype + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device, dtype=self.dtype) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Conv1dBlock(nn.Module): + ''' + Conv1d --> GroupNorm --> Mish + ''' + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class ConditionalResidualBlock1D(nn.Module): + def __init__(self, + in_channels, + out_channels, + cond_dim, + kernel_size=3, + n_groups=8): + super().__init__() + + self.blocks = nn.ModuleList([ + Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), + Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), + ]) + + # FiLM modulation https://arxiv.org/abs/1709.07871 + # predicts per-channel scale and bias + cond_channels = out_channels * 2 + self.out_channels = out_channels + self.cond_encoder = nn.Sequential( + nn.Mish(), + nn.Linear(cond_dim, cond_channels), + nn.Unflatten(-1, (-1, 1)) + ) + + # make sure dimensions compatible + self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ + if in_channels != out_channels else nn.Identity() + + def forward(self, x, cond): + ''' + x 
: [ batch_size x in_channels x horizon ] + cond : [ batch_size x cond_dim] + + returns: + out : [ batch_size x out_channels x horizon ] + ''' + out = self.blocks[0](x) + embed = self.cond_encoder(cond) + + embed = embed.reshape( + embed.shape[0], 2, self.out_channels, 1) + scale = embed[:, 0, ...] + bias = embed[:, 1, ...] + out = scale * out + bias + + out = self.blocks[1](out) + out = out + self.residual_conv(x) + return out + + +class ConditionalUnet1D(PreTrainedModel): + _no_split_modules = ["mid_modules", "down_modules", "up_modules"] + + config_class = UnetDiffusionPolicyConfig + def __init__(self, + config: UnetDiffusionPolicyConfig + ): + """ + input_dim: Dim of actions. + global_cond_dim: Dim of global conditioning applied with FiLM + in addition to diffusion step embedding. This is usually obs_horizon * obs_dim + diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k + down_dims: Channel size for each UNet level. + The length of this array determines numebr of levels. + kernel_size: Conv kernel size + n_groups: Number of groups for GroupNorm + """ + + super().__init__(config) + all_dims = [config.input_dim] + list(config.down_dims) + start_dim = config.down_dims[0] + + self.num_queries = config.prediction_horizon + self.noise_samples = config.noise_samples + # self.global_1d_pool = nn.AdaptiveAvgPool1d(1) + # self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim) + self.norm_after_pool = nn.LayerNorm(config.global_cond_dim) + self.combine = nn.Linear(config.global_cond_dim+config.state_dim, config.global_cond_dim) + dsed = config.diffusion_step_embed_dim + diffusion_step_encoder = nn.Sequential( + SinusoidalPosEmb(dsed, torch.bfloat16), + nn.Linear(dsed, dsed * 4), + nn.Mish(), + nn.Linear(dsed * 4, dsed), + ) + cond_dim = dsed + config.global_cond_dim + + in_out = list(zip(all_dims[:-1], all_dims[1:])) + mid_dim = all_dims[-1] + self.mid_modules = nn.ModuleList([ + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=config.kernel_size, n_groups=config.n_groups + ), + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=config.kernel_size, n_groups=config.n_groups + ), + ]) + + down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + down_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_in, dim_out, cond_dim=cond_dim, + kernel_size=config.kernel_size, n_groups=config.n_groups), + ConditionalResidualBlock1D( + dim_out, dim_out, cond_dim=cond_dim, + kernel_size=config.kernel_size, n_groups=config.n_groups), + Downsample1d(dim_out) if not is_last else nn.Identity() + ])) + + up_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + up_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_out * 2, dim_in, cond_dim=cond_dim, + kernel_size=config.kernel_size, n_groups=config.n_groups), + ConditionalResidualBlock1D( + dim_in, dim_in, cond_dim=cond_dim, + kernel_size=config.kernel_size, n_groups=config.n_groups), + Upsample1d(dim_in) if not is_last else nn.Identity() + ])) + + final_conv = nn.Sequential( + Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size), + nn.Conv1d(start_dim, config.input_dim, 1), + ) + + self.diffusion_step_encoder = diffusion_step_encoder + self.up_modules = up_modules + self.down_modules = down_modules + self.final_conv = final_conv + + print("number of parameters: {:e}".format( + 
sum(p.numel() for p in self.parameters())) + ) + + from diffusers.schedulers.scheduling_ddim import DDIMScheduler + self.num_inference_timesteps = config.num_inference_timesteps + # self.proj_to_action = nn.Identity() + self.noise_scheduler = DDIMScheduler( + num_train_timesteps=config.num_train_timesteps, # 100 + beta_schedule='squaredcos_cap_v2', + clip_sample=True, + set_alpha_to_one=True, + steps_offset=0, + prediction_type='epsilon' + ) + + # self.num_inference_timesteps = config.num_inference_timesteps # 100 + + def forward(self, actions, hidden_states, states, is_pad): + """ + Forward pass for the diffusion head. + :param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1 + :param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [B,Tokens, D] 8 1200 1024 + :param states: robot states, shape [B, D] + :return: loss + """ + if actions is not None: # training time + B = actions.size(0) + actions = copy.deepcopy(actions[:, :self.num_queries]) + is_pad = copy.deepcopy(is_pad[:, :self.num_queries]) + num_noise_samples = self.noise_samples + # sample noise to add to actions + noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device, + dtype=actions.dtype) # num_noise, B, Ta, D + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, + (B,), device=actions.device + ).long() + + timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) + + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = torch.cat([self.noise_scheduler.add_noise( + actions, noise[i], timesteps) + for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim] + + noisy_actions = noisy_actions.to(dtype=actions.dtype) + assert hidden_states.ndim == 3 + + hidden_states = hidden_states.repeat(num_noise_samples, 1, 1) + timesteps = timesteps.repeat(num_noise_samples) + is_pad = is_pad.repeat(num_noise_samples, 1) + states = states.repeat(num_noise_samples, 1) + + noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states) + noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none') + loss = (loss * ~is_pad.unsqueeze(-1)).mean() + # loss_dict['loss'] = loss + return {'loss': loss} + # return loss + else: # inference time + B = 1 + Tp = self.num_queries + action_dim = 14 + + # initialize action from Guassian noise + noisy_action = torch.randn((B, Tp, action_dim)).cuda() + + naction = noisy_action.to(dtype=hidden_states.dtype) + # init scheduler + self.noise_scheduler.set_timesteps(self.num_inference_timesteps) + + for k in self.noise_scheduler.timesteps: + # predict noise + noise_pred = self.model_forward(naction, k, global_cond=hidden_states, states=states) + + # inverse diffusion step (remove noise) + naction = self.noise_scheduler.step( + model_output=noise_pred, + timestep=k, + sample=naction + ).prev_sample + + return naction + + def model_forward(self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + global_cond=None, + states=None): + """ + x: (B,T,input_dim) + timestep: (B,) or int, diffusion step + global_cond: (B,global_cond_dim) + output: (B,T,input_dim) + """ + # (B,T,C) + sample = sample.moveaxis(-1, -2) + # (B,C,T) + # global_cond = self.global_1d_pool(global_cond.permute(0, 2, 
1)).squeeze(-1) + global_cond = global_cond.squeeze(1) + + global_cond = self.norm_after_pool(global_cond) + global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond + global_cond = self.combine(global_cond) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + global_feature = self.diffusion_step_encoder(timesteps) + + if global_cond is not None: + global_feature = torch.cat([ + global_feature, global_cond + ], axis=-1) + + x = sample + h = [] + for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + h.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + x = torch.cat((x, h.pop()), dim=1) + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + # (B,C,T) + x = x.moveaxis(-1, -2) + # (B,T,C) + return x + +AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D) From d67fdf638d6c10882b07d33069fe4b2580a79447 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Sun, 23 Feb 2025 15:21:12 +0800 Subject: [PATCH 06/36] add qwen2_vla --- .../qwe2_vla/configuration_qwen2_vla.py | 252 +++ .../dexvla/qwe2_vla/modeling_qwen2_vla.py | 1995 +++++++++++++++++ 2 files changed, 2247 insertions(+) create mode 100644 lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py create mode 100644 lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py new file mode 100644 index 00000000..a1a1d81f --- /dev/null +++ b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
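Note: the policy heads added above are wired into the transformers Auto* registry
(AutoConfig.register / AutoModel.register), which is how DexVLAConfig later builds its
head via AutoConfig.for_model. A minimal sketch of that flow, assuming the policy-head
modules have already been imported so their register() calls have run; the keyword
values below are illustrative, not taken from any released checkpoint:

    from transformers import AutoConfig, AutoModel

    # Registered as "unet_diffusion_policy"; the ScaleDP head is registered
    # as "scale_dp_policy" and takes its own keyword arguments.
    head_cfg = AutoConfig.for_model(
        "unet_diffusion_policy",
        action_dim=10,
        state_dim=7,
        global_cond_dim=1536,      # assumed: hidden size of the VLM backbone
        prediction_horizon=16,
    )
    policy_head = AutoModel.from_config(head_cfg)   # -> ConditionalUnet1D instance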
+"""Qwen2VL model configuration""" + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging +from transformers import AutoModel, AutoConfig + +logger = logging.get_logger(__name__) + + +class Qwen2VLVisionConfig(PretrainedConfig): + model_type = "qwen2_vl" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if config_dict.get("model_type") == "qwen2_vl": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class Qwen2VLAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. 
The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig + + >>> # Initializing a Qwen2VL style configuration + >>> configuration = Qwen2VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_vla" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + # For loading policy head + policy_head_type='scale_dp_policy', # unet_diffusion_policy + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.policy_head_type = policy_head_type # for loading policy head + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. 
to have scaled RoPE + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + +from transformers import AutoConfig +AutoConfig.register("qwen2_vla", Qwen2VLAConfig) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py new file mode 100644 index 00000000..e37fea19 --- /dev/null +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -0,0 +1,1995 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2-VL model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss, LayerNorm + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + ModelOutput, +) +from lerobot.common.policies.dexvla.fusion_modules import * + +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig +from transformers import AutoConfig, AutoModel +import gc + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + + from transformers.modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2VLConfig" + + +@dataclass +class Qwen2VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2VLRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[Qwen2VLAConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). 
+ + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = LayerNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class VisionMlp(nn.Module): + def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.act = ACT2FN[hidden_act] + self.fc2 = nn.Linear(hidden_dim, dim) + + def forward(self, x) -> torch.Tensor: + return self.fc2(self.act(self.fc1(x))) + + +class VisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = 
apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_VL_VISION_ATTENTION_CLASSES = { + "eager": VisionAttention, + "flash_attention_2": VisionFlashAttention2, + "sdpa": VisionSdpaAttention, +} + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) + self.norm2 = LayerNorm(config.embed_dim, eps=1e-6) + mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) + + self.attn = 
QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.embed_dim, num_heads=config.num_heads + ) + self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states).to(torch.bfloat16), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2VLAConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2VLRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + 1 + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLFlashAttention2(Qwen2VLAttention): + """ + Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
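+        # Illustration (hypothetical q_len=2, k_len=4 with a KV cache): a bottom-right aligned causal mask has rows
+        # [1, 1, 1, 0] and [1, 1, 1, 1], so each new query sees the whole cache plus itself, whereas a top-left
+        # aligned mask would give [1, 0, 0, 0] and [1, 1, 0, 0], wrongly hiding the cached keys.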
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLSdpaAttention(Qwen2VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
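+        # In practice (illustrative, not exhaustive): a mask-free prefill has causal_mask=None and q_len > 1, so
+        # is_causal=True and SDPA builds the causal mask internally; single-token decoding (q_len == 1) or a padded
+        # batch (causal_mask is a tensor) keeps is_causal=False and relies on the explicit attn_mask instead.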
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_VL_ATTENTION_CLASSES = { + "eager": Qwen2VLAttention, + "flash_attention_2": Qwen2VLFlashAttention2, + "sdpa": Qwen2VLSdpaAttention, +} + + +class Qwen2VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2VLAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +QWEN2VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + QWEN2VL_START_DOCSTRING, +) +class Qwen2VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2VLAConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock", "policy_head"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): + config_class = Qwen2VLVisionConfig + _no_split_modules = ["Qwen2VLVisionBlock"] + + def __init__(self, config) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = PatchMerger( + dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size + ) + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + return self.merger(hidden_states) + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + 
QWEN2VL_START_DOCSTRING, +) +class Qwen2VLModel(Qwen2VLPreTrainedModel): + def __init__(self, config: Qwen2VLAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. 
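+        # e.g. a 1D cache_position of shape (seq_len,) is broadcast to position_ids of shape (3, batch, seq_len),
+        # and a user-supplied 2D (batch, seq_len) tensor is likewise repeated across the three rope axes below.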
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
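+        # Concretely: when the checks below pass, this method returns None and the SDPA attention path falls back to
+        # its own `is_causal` handling; otherwise a full (batch, 1, query_len, key_len) float mask is materialised.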
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2VLAConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +QWEN2_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~policy_heads.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+""" + +class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2VisionTransformerPretrainedModel._from_config( + config.vision_config, attn_implementation=config._attn_implementation + ) + self.model = Qwen2VLModel(config) + self.vocab_size = config.vocab_size + self.with_llm_head = config.with_llm_head + + self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides + self.using_film = config.using_film + + self.llm_loss_weight = config.llm_loss_weight + + if isinstance(config.policy_head_config, dict): + config.policy_head_config = AutoConfig.for_model(**config.policy_head_config) + self.policy_head = AutoModel.from_config(config=config.policy_head_config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + if config.policy_head_config.model_type == "scale_dp_policy": + self.policy_head.init_weights() + self.input_action_proj = ActionProjector(config.hidden_size, config.hidden_size) + + if self.using_film: + # Initialize projection layers and condition modulation layers + self.reasoning_action_proj = ActionProjector(config.hidden_size, config.hidden_size) + self.reasoning_film = FiLM(feature_dim=config.hidden_size, condition_dim=config.hidden_size) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embeddin for text part. + Examples: + Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [3, 4, 5, 6, 7] + text height position_ids: [3, 4, 5, 6, 7] + text width position_ids: [3, 4, 5, 6, 7] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if image_grid_thw is not None or video_grid_thw is not None: + total_input_ids = input_ids + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + 
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + if getattr(outputs, "rope_deltas", None) is not None: + model_kwargs["rope_deltas"] = outputs.rope_deltas + + return model_kwargs + + @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + actions: Optional[torch.LongTensor] = None, + states: Optional[torch.FloatTensor] = None, + is_pad: bool = False, + is_eval: bool = False, + tinyvla: bool = False, + ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + self.computed_type = torch.bfloat16 + input_ids=input_ids.to("cuda") + attention_mask=attention_mask.to("cuda") + if not is_eval: + labels = labels.to("cuda") + actions = actions.to(dtype=self.computed_type, device='cuda') + states = states.to(dtype=self.computed_type, device='cuda') + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + if pixel_values is not None: + pixel_values = pixel_values.to(dtype=self.computed_type, device='cuda') + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, 
video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if tinyvla: # dex-vla supports tinyvla-style VLA + return hidden_states + + if self.with_llm_head: + logits = self.lm_head(hidden_states) + logits = logits.float() + else: + logits = None + self.llm_head = None + + llm_loss = None + # cross-entropy loss for VLM + if labels is not None and self.with_llm_head: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + llm_loss = loss_fct(shift_logits, shift_labels) + + # for evaluation + if is_eval: + loss = None + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + if self.using_film: + action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids, + hidden_states=hidden_states) + else: # tinyvla + action_hidden_states = hidden_states + + ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad) + + if self.with_llm_head: + loss = {'loss': ret['loss'] + self.llm_loss_weight * llm_loss, + 'llm_loss': llm_loss, + 'action_loss': ret['loss']} + else: + loss = {'loss': ret['loss'], + 'llm_loss': (torch.ones(1)*(-100)).to(ret['loss'].dtype).squeeze(0), + 'action_loss': ret['loss']} + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + torch.cuda.empty_cache() + gc.collect() + del input_ids + del attention_mask + del position_ids + del past_key_values + del inputs_embeds + del labels + del pixel_values + del image_grid_thw + del actions + del states + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + def film_forward(self, labels, input_ids, hidden_states): + """ + Perform the forward pass for the film module. 
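+
+        The `-100` entries in `labels` mark the prompt (instruction and observation) tokens, so the boundary
+        between the prompt span and the reasoning span is recovered per sample; pad tokens (id 151643) at the
+        start of the sequence are skipped. Prompt hidden states are projected with `input_action_proj`,
+        reasoning hidden states with `reasoning_action_proj`, and the two are fused by the FiLM block
+        `reasoning_film`; a mean-pooled copy of the prompt hidden states is added back as a residual. The
+        result is the conditioning vector handed to the policy head.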
+ """ + inputs_index = labels[:, :] == -100 + inputs_index = inputs_index.int() + + xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:]) + indexs = torch.argmax((xor_array != 0).float(), dim=1) + input_embeddings = [] + reasoning_embeddings = [] + identity = [] + for i in range(indexs.shape[0]): + end = indexs[i] + 1 + temp = input_ids[i] == 151643 # pad token id for qwen2_vl + start = sum(temp.int()) + input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :])) + identity.append(torch.mean(hidden_states[i, start:end, :], dim=0)) + + reasoning_embeddings.append(self.reasoning_action_proj(hidden_states[i, end:, :])) + input_embeddings = torch.cat(input_embeddings, dim=0) + reasoning_embeddings = torch.cat(reasoning_embeddings, dim=0) + identity = torch.stack(identity) + + action_hidden_states = self.reasoning_film(input_embeddings, reasoning_embeddings).unsqueeze(1) + + action_hidden_states = action_hidden_states + identity.unsqueeze(1) + return action_hidden_states + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0]:] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + rope_deltas = kwargs.get("rope_deltas", None) + if attention_mask is not None and position_ids is None: + if cache_position is None or (cache_position is not None and cache_position[0] == 0): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + else: + batch_size, seq_length = input_ids.shape + delta = ( + cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + ) + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + 
cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "rope_deltas": rope_deltas, + } + ) + model_inputs.update(kwargs) + return model_inputs + + + +from transformers import AutoModelForCausalLM +AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) From d0f6fce0cbccd937301bc0c0d6dd365c231d0a44 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Sun, 23 Feb 2025 15:21:30 +0800 Subject: [PATCH 07/36] update dexvla --- .../policies/dexvla/configuration_dexvla.py | 355 +-- .../common/policies/dexvla/modeling_dexvla.py | 2151 ++--------------- 2 files changed, 313 insertions(+), 2193 deletions(-) diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index 2ba44f44..ae225c27 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -14,248 +14,141 @@ # limitations under the License. """Qwen2VL model configuration""" -import os -from typing import Union +from typing import Tuple -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import rope_config_validation +from dataclasses import dataclass, field + +from transformers import AutoConfig + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, +) from transformers.utils import logging -from transformers import AutoModel, AutoConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.common.policies.dexvla.policy_heads.configuration_scaledp import ScaleDPPolicyConfig +from lerobot.common.policies.dexvla.policy_heads.configuration_unet_diffusion import UnetDiffusionPolicyConfig +from lerobot.common.policies.dexvla.qwe2_vla.configuration_qwen2_vla import Qwen2VLAConfig +from lerobot.configs.types import NormalizationMode logger = logging.get_logger(__name__) +@PreTrainedConfig.register_subclass("dexvla") +@dataclass +class DexVLAConfig(PreTrainedConfig): + # For loading policy head + policy_head_type: str = 'scale_dp_policy' + policy_head_size: str = 'ScaleDP_L' + action_dim: int = 14 + state_dim: int = 14 + chunk_size: int = 50 + n_action_steps: int = 50 + n_obs_steps: int = 1 + hidden_size: int = 1536 + qwen2_vla_path: str = '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' -class Qwen2VLAVisionConfig(PretrainedConfig): - model_type = "dex_vla" + pretrained_path: str = None # pretrained dexvla + using_film: bool = True + llm_loss_weight: float = 1.0 + with_llm_head: bool = True + using_reasoning: bool = True + resize_size: tuple = (240, 320) + # Training presets + optimizer_lr: float = 2e-5 + optimizer_betas: Tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 - def __init__( - self, - depth=32, - embed_dim=1280, - hidden_size=3584, - hidden_act="quick_gelu", - mlp_ratio=4, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - **kwargs, - ): - super().__init__(**kwargs) + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 
30_000 + scheduler_decay_lr: float = 2.5e-6 - self.depth = depth - self.embed_dim = embed_dim - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + # "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MIN_MAX, + } + ) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - - if config_dict.get("model_type") == "qwen2_vl": - config_dict = config_dict["vision_config"] - - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: - logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + def __post_init__(self): + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + if self.using_reasoning: + assert self.using_film, f"using_reasoning requires `using_film=True`" + assert self.with_llm_head, f"using_reasoning requires `with_llm_head=True`" + print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.") + else: + print(f"Warning:DexVLA recommend to use reasoning data which can better handle long-horizon and dexterous tasks. 
You can set 'using_reaasoning=True'.") - return cls.from_dict(config_dict, **kwargs) + if self.policy_head_type == 'scale_dp_policy': + self.policy_head_config = AutoConfig.for_model( + model_type=self.policy_head_type, + model_size=self.policy_head_size, + cond_dim=self.hidden_size, + action_dim=self.action_dim, + prediction_horizon=self.chunk_size, + state_dim=self.state_dim + ) + elif self.policy_head_type == 'unet_diffusion': + self.policy_head_config = AutoConfig.for_model( + model_type=self.policy_head_type, + global_cond_dim=self.hidden_size, + action_dim=self.action_dim, + state_dim=self.state_dim + ) + else: + raise ValueError(f'Policy head type {self.policy_head_type} not supported') + + self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vla_path) + + def validate_features(self) -> None: + # TODO: implement value error + if not self.image_features and not self.env_state_feature: + raise ValueError("You must provide at least one image or the environment state among the inputs.") + + # for i in range(self.empty_cameras): + # key = f"observation.images.empty_camera_{i}" + # empty_camera = PolicyFeature( + # type=FeatureType.VISUAL, + # shape=(3, 480, 640), + # ) + # self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None -class DexVLAConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a - Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 152064): - Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2VLModel`] - hidden_size (`int`, *optional*, defaults to 8192): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 29568): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 80): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 64): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 80): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - vision_config (`Dict`, *optional*): - The config for the visual encoder initialization. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type - and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value - accordingly. - Expected contents: - `rope_type` (`str`): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - `factor` (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - `original_max_position_embeddings` (`int`, *optional*): - Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - `attention_factor` (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - `beta_fast` (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - `beta_slow` (`float`, *optional*): - Only used with 'yarn'. 
Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - `short_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `long_factor` (`List[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - `low_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - `high_freq_factor` (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - - ```python - >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig - - >>> # Initializing a Qwen2VL style configuration - >>> configuration = Qwen2VLConfig() - - >>> # Initializing a model from the Qwen2-VL-7B style configuration - >>> model = Qwen2VLForConditionalGeneration(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2_vla" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=152064, - hidden_size=8192, - intermediate_size=29568, - num_hidden_layers=80, - num_attention_heads=64, - num_key_value_heads=8, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-05, - use_cache=True, - tie_word_embeddings=False, - rope_theta=1000000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=80, - attention_dropout=0.0, - vision_config=None, - rope_scaling=None, - # For loading policy head - policy_head_type='dit_diffusion_policy', # dit_diffusion_policy - policy_head_size='DiT_L', - action_dim=10, - state_dim=7, - chunk_size=50, - **kwargs, - ): - if isinstance(vision_config, dict): - self.vision_config = Qwen2VLAVisionConfig(**vision_config) - elif vision_config is None: - self.vision_config = Qwen2VLAVisionConfig() - - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for loading policy head - self.policy_head_type = policy_head_type - self.policy_head_config = AutoConfig.for_model(model_type=policy_head_type, - model_size=policy_head_size, - cond_dim=hidden_size, action_dim=action_dim, - prediction_horizon=chunk_size, - state_dim=state_dim) - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.rope_scaling = rope_scaling - - # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. 
- # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations - # one can set it to "linear"/"dynamic" etc. to have scaled RoPE - # TODO: @raushan update config in the hub - if self.rope_scaling is not None and "type" in self.rope_scaling: - if self.rope_scaling["type"] == "mrope": - self.rope_scaling["type"] = "default" - self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self, ignore_keys={"mrope_section"}) - - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index 9af6853f..bea643c4 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -1,1998 +1,141 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch Qwen2-VL model.""" - -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, LayerNorm +from torch import Tensor -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import ( - AttentionMaskConverter, +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig +from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import ( + Qwen2VLForConditionalGenerationForVLA ) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - ModelOutput, -) -from fusion_modules import ActionProjector, FiLM -from types import SimpleNamespace +from lerobot.common.policies.pretrained import PreTrainedPolicy +from collections import deque +from lerobot.common.policies.dexvla.policy_heads.modeling_unet_diffusion import ConditionalUnet1D +from lerobot.common.policies.dexvla.policy_heads.modeling_scaledp import ScaleDP +from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess +from transformers import AutoProcessor, AutoTokenizer +import torchvision.transforms as transforms -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_dexvla import DexVLAConfig, Qwen2VLAVisionConfig -import gc +class DexVLAPolicy(PreTrainedPolicy): + """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - - from transformers.modeling_flash_attention_utils import _flash_attention_forward -else: - flash_attn_varlen_func = None - -from transformers import AutoConfig, AutoModel - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "Qwen2VLConfig" - - -@dataclass -class Qwen2VLCausalLMOutputWithPast(ModelOutput): - """ - Base class for Qwen2VL causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. 
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None - - -class Qwen2VLRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[DexVLAConfig] = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids - # So we expand the inv_freq to shape (3, ...) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). 
- - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - -class PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_channels: int = 3, - embed_dim: int = 1152, - ) -> None: - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - self.embed_dim = embed_dim - - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype - hidden_states = hidden_states.view( - -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) - return hidden_states - - -class PatchMerger(nn.Module): - def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: - super().__init__() - self.hidden_size = context_dim * (spatial_merge_size**2) - self.ln_q = LayerNorm(context_dim, eps=1e-6) - self.mlp = nn.Sequential( - nn.Linear(self.hidden_size, self.hidden_size), - nn.GELU(), - nn.Linear(self.hidden_size, dim), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) - return x - - -class VisionMlp(nn.Module): - def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: - super().__init__() - self.fc1 = nn.Linear(dim, hidden_dim) - self.act = ACT2FN[hidden_act] - self.fc2 = nn.Linear(hidden_dim, dim) - - def forward(self, x) -> torch.Tensor: - return self.fc2(self.act(self.fc1(x))) - - -class VisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = 
apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.full( - [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - -class VisionFlashAttention2(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( - seq_length, -1 - ) - attn_output = self.proj(attn_output) - return attn_output - - -class VisionSdpaAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) - - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output - - -QWEN2_VL_VISION_ATTENTION_CLASSES = { - "eager": VisionAttention, - "flash_attention_2": VisionFlashAttention2, - "sdpa": VisionSdpaAttention, -} - - -class Qwen2VLVisionBlock(nn.Module): - def __init__(self, config, attn_implementation: str = "sdpa") -> None: - super().__init__() - self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) - self.norm2 = LayerNorm(config.embed_dim, eps=1e-6) - mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) - - self.attn = 
QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( - config.embed_dim, num_heads=config.num_heads - ) - self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) - - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: - hidden_states = hidden_states + self.attn( - self.norm1(hidden_states).to(torch.bfloat16), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb - ) - hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - return hidden_states - - -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2VLAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: DexVLAConfig, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - self.rope_scaling = config.rope_scaling - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2VLRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # Fix precision issues in Qwen2-VL float16 inference - # Replace inf values with zeros in attention weights to prevent NaN propagation - if query_states.dtype == torch.float16: - attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2VLFlashAttention2(Qwen2VLAttention): - """ - Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2VLSdpaAttention(Qwen2VLAttention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_VL_ATTENTION_CLASSES = { - "eager": Qwen2VLAttention, - "flash_attention_2": Qwen2VLFlashAttention2, - "sdpa": Qwen2VLSdpaAttention, -} - - -class Qwen2VLDecoderLayer(nn.Module): - def __init__(self, config: DexVLAConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." - ) - self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2VL_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2VLConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", - QWEN2VL_START_DOCSTRING, -) -class Qwen2VLPreTrainedModel(PreTrainedModel): config_class = DexVLAConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock", "policy_head"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = True + name = "dexvla" - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): - config_class = Qwen2VLAVisionConfig - _no_split_modules = ["Qwen2VLVisionBlock"] - - def __init__(self, config) -> None: - super().__init__(config) - self.spatial_merge_size = config.spatial_merge_size - - self.patch_embed = PatchEmbed( - patch_size=config.patch_size, - temporal_patch_size=config.temporal_patch_size, - in_channels=config.in_channels, - embed_dim=config.embed_dim, - ) - - head_dim = config.embed_dim // config.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) - - self.blocks = nn.ModuleList( - [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] - ) - self.merger = PatchMerger( - dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size - ) - - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype - - def get_device(self) -> torch.device: - return self.blocks[0].mlp.fc2.weight.device - - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32 - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - for blk in self.blocks: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) - - return self.merger(hidden_states) - - -@add_start_docstrings( - "The bare Qwen2VL Model outputting raw hidden-states without any specific head 
on top.", - QWEN2VL_START_DOCSTRING, -) -class Qwen2VLModel(Qwen2VLPreTrainedModel): - def __init__(self, config: DexVLAConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( + def __init__( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - # the hard coded `3` is for temporal, height and width. 
- if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) - elif position_ids.dim() == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, config: DexVLAConfig, - past_key_values: Cache, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. 
- cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen2VLConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. """ - #print('@'*50) - #print(attention_mask.shape) - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask |= sliding_attend_mask - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - -QWEN2_VL_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): - The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses - [`Qwen2VLImageProcessor`] for processing images. - pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): - The tensors corresponding to the input videos. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses - [`Qwen2VLImageProcessor`] for processing videos. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. 
-""" - -class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): super().__init__(config) - self.visual = Qwen2VisionTransformerPretrainedModel._from_config( - config.vision_config, attn_implementation=config._attn_implementation + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats ) - self.model = Qwen2VLModel(config) - self.vocab_size = config.vocab_size - - self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides - - self.llm_loss_weight = config.llm_loss_weight - - if isinstance(config.policy_head_config, dict): - config.policy_head_config = AutoConfig.for_model(**config.policy_head_config) - self.policy_head = AutoModel.from_config(config=config.policy_head_config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - self.policy_head.init_weights() - self.input_action_proj = ActionProjector(config.hidden_size, config.hidden_size) - - if self.using_film: - self.reasoning_action_proj = ActionProjector(config.hidden_size, config.hidden_size) - self.reasoning_film = FiLM(feature_dim=config.hidden_size, condition_dim=config.hidden_size) - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def get_rope_index( - self, - input_ids: torch.LongTensor, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Calculate the 3D rope index based on image and video's temporal, height and width in LLM. - - Explanation: - Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - - For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. - Examples: - input_ids: [T T T T T], here T is for text. - temporal position_ids: [0, 1, 2, 3, 4] - height position_ids: [0, 1, 2, 3, 4] - width position_ids: [0, 1, 2, 3, 4] - - For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part - and 1D rotary position embeddin for text part. - Examples: - Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. - input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. - vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] - vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] - vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] - text temporal position_ids: [3, 4, 5, 6, 7] - text height position_ids: [3, 4, 5, 6, 7] - text width position_ids: [3, 4, 5, 6, 7] - Here we calculate the text start position_ids as the max vision position_ids plus 1. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide - it. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - Returns: - position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) - mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) - """ - spatial_merge_size = self.config.vision_config.spatial_merge_size - image_token_id = self.config.image_token_id - video_token_id = self.config.video_token_id - vision_start_token_id = self.config.vision_start_token_id - mrope_position_deltas = [] - if image_grid_thw is not None or video_grid_thw is not None: - total_input_ids = input_ids - position_ids = torch.ones( - 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device - ) - image_index, video_index = 0, 0 - for i, input_ids in enumerate(total_input_ids): - if attention_mask is not None: - input_ids = input_ids[attention_mask[i] == 1] - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - 
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) - return position_ids, mrope_position_deltas - else: - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] - else: - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .expand(3, input_ids.shape[0], -1) - ) - mrope_position_deltas = torch.zeros( - [input_ids.shape[0], 1], - device=input_ids.device, - dtype=input_ids.dtype, - ) - - return position_ids, mrope_position_deltas - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - num_new_tokens: int = 1, - ) -> Dict[str, Any]: - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - num_new_tokens=num_new_tokens, + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats ) - if getattr(outputs, "rope_deltas", None) is not None: - model_kwargs["rope_deltas"] = outputs.rope_deltas + for k in ['using_film', 'llm_loss_weight', 'with_llm_head', 'policy_head_config']: + setattr(config.qwen2_vla_config, k, config.__dict__[k]) - return model_kwargs + self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16) + self.model.requires_grad_(False) + self.model.policy_head.requires_grad_(True) + self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vla_path) + self.tokenizer = AutoTokenizer.from_pretrained( + config.qwen2_vla_path + ) + self.vla_processor = Qwen2VLAProcess(tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor) # process the input data into VLM format - @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - actions: Optional[torch.LongTensor] = None, - states: Optional[torch.FloatTensor] = None, - is_pad: bool = False, - is_eval: bool = False, - ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels 
for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - - >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, + self.resize_size = self.config.resize_size + ratio = 0.95 + self.transformations = [ + transforms.Resize(size=self.resize_size, antialias=True), + transforms.RandomCrop(size=[int(self.resize_size[0] * ratio), int(self.resize_size[1] * ratio)]), + transforms.Resize(self.resize_size, antialias=True), + transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False), + transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5), # , hue=0.08) ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + self.reset() - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." - ```""" + def process_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Applying DexVLA preprocessing to original data. Including resizing images. 
Scaling the range of actions, states.""" + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + present_img_keys = [key for key in self.config.image_features if key in batch] + task_descs = batch['task'] + try: + reasonings = batch['reasoning'] + except KeyError: + reasonings = ['no reasoning'] * len(task_descs) - self.computed_type = torch.bfloat16 - input_ids = input_ids.to("cuda") - attention_mask = attention_mask.to("cuda") - if not is_eval: - labels = labels.to("cuda") - actions = actions.to(dtype=self.computed_type, device='cuda') - states = states.to(dtype=self.computed_type, device='cuda') - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - if pixel_values is not None: - pixel_values = pixel_values.to(dtype=self.computed_type, device='cuda') - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + pass + is_pad = batch['action_is_pad'] + all_cam_images = [] + for k in present_img_keys: + all_cam_images.append(batch[k]) - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + # construct observations, and scale 0-1 to 0-255 + image_data = torch.stack(all_cam_images) * 255 + image_data = image_data.to(dtype=torch.uint8) + # construct observations + qpos_data = batch['observation.state'].float() + action_data = batch['action'].float() - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + orig_shape = image_data.shape + image_data = image_data.view(-1, *orig_shape[2:]) - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) + for transform in self.transformations: + image_data = transform(image_data) - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - 
inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=return_dict, - ) + image_data = image_data.view(*orig_shape[:3], *self.resize_size) - hidden_states = outputs[0] - - logits = self.lm_head(hidden_states) - logits = logits.float() - - llm_loss = None - - # cross-entropy loss for VLM - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - llm_loss = loss_fct(shift_logits, shift_labels) - - # for evaluation - if is_eval: - loss = None - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Qwen2VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=rope_deltas, - ) - - action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids, - hidden_states=hidden_states) + vl_data = { + 'images': image_data, + 'raw_langs': task_descs, + 'reasonings': reasonings + } + # processing vl_data into qwen2_vl format + vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning) + vla_inputs['states'] = qpos_data + vla_inputs['is_pad'] = is_pad + vla_inputs['actions'] = action_data + return vla_inputs - ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad) + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: - loss = {'loss': ret['loss'] + self.llm_loss_weight * llm_loss, - 'llm_loss': llm_loss, - 'action_loss': ret['loss']} - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + processed_batch = self.process_batch(batch) - torch.cuda.empty_cache() - gc.collect() - del input_ids - del attention_mask - del position_ids - del past_key_values - del inputs_embeds - del labels - del pixel_values - del image_grid_thw - del actions - del states - return Qwen2VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=rope_deltas, - ) + ret = self.model.forward(**processed_batch) + loss_dict = ret['loss'] + loss = loss_dict['loss'].mean() + return loss, loss_dict - def film_forward(self, labels, input_ids, hidden_states): - """ - Perform the forward pass for the film module. 
- """ - inputs_index = labels[:, :] == -100 - inputs_index = inputs_index.int() - - xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:]) - indexs = torch.argmax((xor_array != 0).float(), dim=1) - input_embeddings = [] - reasoning_embeddings = [] - identity = [] - for i in range(indexs.shape[0]): - end = indexs[i] + 1 - temp = input_ids[i] == 151643 # pad token id for qwen2_vl - start = sum(temp.int()) - input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :])) - identity.append(torch.mean(hidden_states[i, start:end, :], dim=0)) - - reasoning_embeddings.append(self.reasoning_action_proj(hidden_states[i, end:, :])) - input_embeddings = torch.cat(input_embeddings, dim=0) - reasoning_embeddings = torch.cat(reasoning_embeddings, dim=0) - identity = torch.stack(identity) - - action_hidden_states = self.reasoning_film(input_embeddings, reasoning_embeddings).unsqueeze(1) - - action_hidden_states = action_hidden_states + identity.unsqueeze(1) - return action_hidden_states - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - pixel_values=None, - pixel_values_videos=None, - image_grid_thw=None, - video_grid_thw=None, - **kwargs, - ): - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0]:] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - rope_deltas = kwargs.get("rope_deltas", None) - if attention_mask is not None and position_ids is None: - if cache_position is None or (cache_position is not None and cache_position[0] == 0): - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - else: - batch_size, seq_length = input_ids.shape - delta = ( - cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 - ) - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - - if cache_position[0] != 0: - pixel_values = None - pixel_values_videos = None - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - model_inputs = {"input_ids": input_ids, "inputs_embeds": None} - - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = inputs_embeds.shape - device = inputs_embeds.device - else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - 
cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_grid_thw": image_grid_thw, - "video_grid_thw": video_grid_thw, - "rope_deltas": rope_deltas, - } - ) - model_inputs.update(kwargs) - return model_inputs - - def evaluate(self, - input_ids: torch.LongTensor = None, - actions=None, - states=None, - is_pad=None, - tokenizer=None, - is_eval=True, - pixel_values=None, - attention_mask=None, - image_grid_thw=None, - ): + def dexvla_predict_action(self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + tokenizer=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): input_ids = input_ids.to('cuda') with torch.inference_mode(): - outputs = self.generate( + outputs = self.model.generate( input_ids, pixel_values=pixel_values, attention_mask=attention_mask, @@ -2010,6 +153,7 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): ) output_ids = outputs.sequences + # last_hidden_states = outputs.hidden_states[-2][-1] input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: @@ -2022,11 +166,94 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin): action_hidden_states = None - if self.using_film: - action_hidden_states = self.film_forward(labels=torch.ones_like(output_ids), + if self.model.using_film: + action_hidden_states = self.model.film_forward(labels=torch.ones_like(output_ids), input_ids=output_ids, hidden_states=torch.cat(last_hidden_states, dim=1)) - action = self.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad) + action = self.model.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad) return action, outputs_text + def tinyvla_predict_action(self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): + input_ids = input_ids.to('cuda') + with torch.inference_mode(): + all_hidden_states = self.model.forward(input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + is_eval=is_eval, + tinyvla=True) + + all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1) + + action = self.model.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad) + return action, "tinyvla generates no reasoning" + + def reset(self): + """This should be called whenever the environment is reset.""" + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + def get_optim_params(self) -> dict: + return self.parameters() + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. 
+ """ + self.eval() + batch = self.normalize_inputs(batch) + + if len(self._action_queue) == 0: + present_img_keys = [key for key in self.config.image_features if key in batch] + try: + task_descs = batch['task'] + except KeyError: + task_descs = " " + print("No task descriptions found for this task") + + all_cam_images = [] + for k in present_img_keys: + all_cam_images.append(batch[k]) + + # construct observations, and scale 0-1 to 0-255 + image_data = torch.stack(all_cam_images) * 255 + image_data = image_data.to(dtype=torch.uint8) + # construct observations + qpos_data = batch['observation.state'].float() + + image_data = image_data.squeeze(0) + + for transform in self.transformations: + image_data = transform(image_data) + + # processing vl_data into qwen2_vl format + vla_inputs = self.vla_processor.single_forward_process(images=image_data, raw_lang=task_descs, reasoning=None, eval=True) + vla_inputs['states'] = qpos_data + + if self.config.using_film and self.config.with_llm_head: # dexvla + all_actions, outputs = self.dexvla_predict_action(**vla_inputs, is_eval=True, tokenizer=self.tokenizer) + else: # tinyvla + all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True) + + actions = self.unnormalize_outputs({"action": all_actions})["action"] + self._action_queue.extend(actions.transpose(0, 1)) + + return self._action_queue.popleft() + + + + + From f96697fb3159c3759e24bfec261d4405636690ad Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Sun, 23 Feb 2025 15:21:39 +0800 Subject: [PATCH 08/36] remove unused code --- .../policies/dexvla/robot_data_processor.py | 154 +++++++++--------- 1 file changed, 74 insertions(+), 80 deletions(-) diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py index 1f90dc79..85e43572 100644 --- a/lerobot/common/policies/dexvla/robot_data_processor.py +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -8,90 +8,40 @@ from qwen_vl_utils import fetch_image class Qwen2VLAProcess: def __init__( self, - language=None, tokenizer=None, max_seq_len=512, multimodal_processor=None, - camera_names=None, - data_args=None, ): super().__init__() self.tokenizer = tokenizer self.max_seq_len = max_seq_len - self.camera_names = camera_names - # self.language = language self.multimodal_processor = multimodal_processor - self.data_args = data_args - def preprocess_image(self, image, size=224): - # Model has been trained to handle images of different aspects ratios - # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize - # options are helpful to improve quality in some tasks. - image = np.asarray(image) - if image.ndim == 2: # Convert image without last channel into greyscale. - image = np.stack((image,) * 3, axis=-1) - image = image[..., :3] # Remove alpha layer. 
- assert image.shape[-1] == 3 - - image_pil = to_pil_image(image) - - # Step 2: Define the resize transformation - resize_transform = transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR) - - # Step 3: Apply the resize transformation - image_resized_pil = resize_transform(image_pil) - - # Step 4: Convert back to tensor if needed - image_resized = to_tensor(image_resized_pil) - return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1] - - def qwen2_image_preprocess(self, each, camera_name): + def qwen2_image_preprocess(self, each): ele = {} - each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)) + each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)) ele['image'] = each - if 'wrist' in camera_name: - w, h = eval(self.data_args.image_size_wrist) - ele['resized_height'] = h - ele['resized_width'] = w - else: - ele['resized_height'] = each.height - ele['resized_width'] = each.width + + ele['resized_height'] = each.height + ele['resized_width'] = each.width each = fetch_image(ele) return torch.from_numpy(np.array(each)) - def forward_process(self, sample, use_reasoning=True): - if sample['image'].ndim == 5 and sample['image'].shape[1] > 2: - video = True - else: - video = False - messages = self.datastruct_droid2llava(sample, video=video) + def single_forward_process(self, images, raw_lang, reasoning, eval=False, use_reasoning=True): + len_views = images.shape[0] + messages = self.construct_chat_data(len_views, raw_lang) data_dict = dict( messages=messages, - images=None ) - image_data = torch.chunk(sample['image'], sample['image'].shape[0], 0) + image_data = torch.chunk(images, len_views, 0) images_list = [] for i, each in enumerate(image_data): - if each.ndim == 4: - img_pil = self.qwen2_image_preprocess(each, self.camera_names[i]) - else: - img_pil = [] - for temp in each.squeeze(0): - img_pil.append(self.qwen2_image_preprocess(temp, self.camera_names[i])) - img_pil = torch.stack(img_pil, 0) + img_pil = self.qwen2_image_preprocess(each) images_list.append(img_pil) - # TODO RESIZE - # image_data = image_data / 255.0 - if video: - image_data = None - video_inputs = images_list - else: - image_data = images_list - video_inputs = None text = self.multimodal_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True @@ -99,14 +49,18 @@ class Qwen2VLAProcess: model_inputs = self.multimodal_processor( text=text, - images=image_data, - videos=video_inputs, + images=images_list, + videos=None, padding=True, return_tensors="pt", ) + + if eval: + return model_inputs + input_labels = torch.ones_like(model_inputs['input_ids']) * -100 if use_reasoning: - answer = sample['reasoning'] + "Next action:" + '<|im_end|>' + answer =reasoning + "Next action:" + '<|im_end|>' else: answer = '' + '<|im_end|>' @@ -115,37 +69,77 @@ class Qwen2VLAProcess: model_inputs['input_ids'] = torch.cat((model_inputs['input_ids'], output_text['input_ids']), dim=-1) model_inputs['attention_mask'] = torch.cat((model_inputs['attention_mask'], output_text['attention_mask']), dim=-1) labels = torch.cat((input_labels, output_labels), dim=-1) - data_dict['state'] = sample['state'] - data_dict['action'] = sample['action'] - data_dict['is_pad'] = sample['is_pad'] + data_dict['labels'] = labels for k, v in model_inputs.items(): data_dict[k] = v return data_dict - def datastruct_droid2llava(self, sample, video=False): - len_image = sample['image'].shape[0] + def forward(self, batch, use_reasoning=True): + """This is the main 
process function for processing vl data into Qwen2_vl format""" + all_images = batch['images'] + all_images = torch.einsum('v b c h w -> b v c h w', all_images) # camera_views, batch_size, channel, height, width + + ret_l = [] + + for idx, images in enumerate(all_images): + raw_lang = batch['raw_langs'][idx] + reasoning = batch['reasonings'][idx] + ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning) + ret_l.append(ret_dict) + + return self.post_process(ret_l) + + def post_process(self, instances): + input_ids = [torch.flip(instance['input_ids'].squeeze(0), dims=[0]) for instance in instances] + labels = [torch.flip(instance['labels'].squeeze(0), dims=[0]) for instance in instances] + + image_grid_thw = torch.stack([instances['image_grid_thw'] for instances in instances]) + pixel_values = torch.stack([instances['pixel_values'] for instances in instances]) + pixel_values_videos = None + video_grid_thw = None + + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=-100) + labels = torch.flip(labels, dims=[1]) + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + input_ids = torch.flip(input_ids, dims=[1]) + b = input_ids.shape[0] + + image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2]) + pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2]) + + attention_mask = input_ids.ne(self.tokenizer.pad_token_id), + + batch = dict( + input_ids=input_ids, + attention_mask=attention_mask[0], + labels=labels, + image_grid_thw=image_grid_thw, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + pixel_values=pixel_values, + ) + return batch + + def construct_chat_data(self, len_image, raw_lang): messages = [ { "role": "user", "content": [], }, - # {"role": "assistant", "content": f''}, ] for i in range(len_image): - if video: - messages[0]['content'].append({ - "type": "video", - "video": None, - }) - else: - messages[0]['content'].append({ - "type": "image", - "image": None, - }) + messages[0]['content'].append({ + "type": "image", + "image": None, + }) messages[0]['content'].append({"type": "text", "text": f""}) - messages[0]['content'][-1]['text'] = sample['raw_lang'] + messages[0]['content'][-1]['text'] = raw_lang return messages \ No newline at end of file From be5c61ca3904b3d2b075b5cf322bec670694ddfc Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Sun, 23 Feb 2025 15:22:05 +0800 Subject: [PATCH 09/36] update packages requirements for dexvla --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6e7e0575..d5300d80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] pi0 = ["transformers>=4.48.0"] +dexvla = ["transformers>=4.45.2", "qwen_vl_utils>=0.08", "timm==0.9.10"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", From e3be39426b81ba14b30c03a5cb706daca2e10d8b Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Mon, 24 Feb 2025 15:27:22 +0800 Subject: [PATCH 10/36] update qwen2_vl_path check --- lerobot/common/policies/dexvla/configuration_dexvla.py | 9 ++++++--- 
lerobot/common/policies/dexvla/modeling_dexvla.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index ae225c27..6f3c0ef0 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -45,7 +45,7 @@ class DexVLAConfig(PreTrainedConfig): n_obs_steps: int = 1 hidden_size: int = 1536 - qwen2_vla_path: str = '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' + qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' pretrained_path: str = None # pretrained dexvla using_film: bool = True @@ -86,7 +86,10 @@ class DexVLAConfig(PreTrainedConfig): assert self.with_llm_head, f"using_reasoning requires `with_llm_head=True`" print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.") else: - print(f"Warning:DexVLA recommend to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'.") + print(f"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'.") + + if self.qwen2_vl_path is None: + raise ValueError("DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'.") if self.policy_head_type == 'scale_dp_policy': self.policy_head_config = AutoConfig.for_model( @@ -107,7 +110,7 @@ class DexVLAConfig(PreTrainedConfig): else: raise ValueError(f'Policy head type {self.policy_head_type} not supported') - self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vla_path) + self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path) def validate_features(self) -> None: # TODO: implement value error diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index bea643c4..e9330a79 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -50,9 +50,9 @@ class DexVLAPolicy(PreTrainedPolicy): self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16) self.model.requires_grad_(False) self.model.policy_head.requires_grad_(True) - self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vla_path) + self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path) self.tokenizer = AutoTokenizer.from_pretrained( - config.qwen2_vla_path + config.qwen2_vl_path ) self.vla_processor = Qwen2VLAProcess(tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor) # process the input data into VLM format From f7d664dcc0f0b46d12fbfde1aada79c2bc81aa27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Mar 2025 00:35:06 +0000 Subject: [PATCH 11/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/policies/__init__.py | 2 +- .../policies/dexvla/configuration_dexvla.py | 44 ++-- .../common/policies/dexvla/fusion_modules.py | 20 +- .../common/policies/dexvla/modeling_dexvla.py | 167 ++++++------- .../policy_heads/configuration_scaledp.py | 83 ++++--- .../configuration_unet_diffusion.py | 49 ++-- .../dexvla/policy_heads/modeling_scaledp.py | 177 +++++++------- .../policy_heads/modeling_unet_diffusion.py | 219 ++++++++++-------- 
.../qwe2_vla/configuration_qwen2_vla.py | 19 +- .../dexvla/qwe2_vla/modeling_qwen2_vla.py | 191 +++++++++------ .../policies/dexvla/robot_data_processor.py | 84 +++---- lerobot/common/policies/factory.py | 2 +- 12 files changed, 591 insertions(+), 466 deletions(-) diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index 5a7d1b8a..d212ef7e 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -1,6 +1,6 @@ from .act.configuration_act import ACTConfig as ACTConfig +from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig -from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index 6f3c0ef0..b39e74d1 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,30 +13,28 @@ # limitations under the License. """Qwen2VL model configuration""" +from dataclasses import dataclass, field from typing import Tuple -from dataclasses import dataclass, field - from transformers import AutoConfig +from transformers.utils import logging from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) -from transformers.utils import logging from lerobot.configs.policies import PreTrainedConfig -from lerobot.common.policies.dexvla.policy_heads.configuration_scaledp import ScaleDPPolicyConfig -from lerobot.common.policies.dexvla.policy_heads.configuration_unet_diffusion import UnetDiffusionPolicyConfig -from lerobot.common.policies.dexvla.qwe2_vla.configuration_qwen2_vla import Qwen2VLAConfig from lerobot.configs.types import NormalizationMode logger = logging.get_logger(__name__) + + @PreTrainedConfig.register_subclass("dexvla") @dataclass class DexVLAConfig(PreTrainedConfig): # For loading policy head - policy_head_type: str = 'scale_dp_policy' - policy_head_size: str = 'ScaleDP_L' + policy_head_type: str = "scale_dp_policy" + policy_head_size: str = "ScaleDP_L" action_dim: int = 14 state_dim: int = 14 chunk_size: int = 50 @@ -45,9 +42,9 @@ class DexVLAConfig(PreTrainedConfig): n_obs_steps: int = 1 hidden_size: int = 1536 - qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' + qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' - pretrained_path: str = None # pretrained dexvla + pretrained_path: str = None # pretrained dexvla using_film: bool = True llm_loss_weight: float = 1.0 with_llm_head: bool = True @@ -82,33 +79,37 @@ class DexVLAConfig(PreTrainedConfig): f"Multiple observation steps not handled yet. 
Got `nobs_steps={self.n_obs_steps}`" ) if self.using_reasoning: - assert self.using_film, f"using_reasoning requires `using_film=True`" - assert self.with_llm_head, f"using_reasoning requires `with_llm_head=True`" + assert self.using_film, "using_reasoning requires `using_film=True`" + assert self.with_llm_head, "using_reasoning requires `with_llm_head=True`" print("You have set using_reasoning=True, please make sure your data has key 'reasoning'.") else: - print(f"Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'.") + print( + "Warning:DexVLA recommends to use reasoning data which can better handle long-horizon and dexterous tasks. You can set 'using_reaasoning=True'." + ) if self.qwen2_vl_path is None: - raise ValueError("DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'.") + raise ValueError( + "DexVLA is built on official qwen2_vl-2B. You have to download the official weights of qwen2_vl-2B first and set 'qwen2_vl_path'." + ) - if self.policy_head_type == 'scale_dp_policy': + if self.policy_head_type == "scale_dp_policy": self.policy_head_config = AutoConfig.for_model( model_type=self.policy_head_type, model_size=self.policy_head_size, cond_dim=self.hidden_size, action_dim=self.action_dim, prediction_horizon=self.chunk_size, - state_dim=self.state_dim + state_dim=self.state_dim, ) - elif self.policy_head_type == 'unet_diffusion': + elif self.policy_head_type == "unet_diffusion": self.policy_head_config = AutoConfig.for_model( model_type=self.policy_head_type, global_cond_dim=self.hidden_size, action_dim=self.action_dim, - state_dim=self.state_dim + state_dim=self.state_dim, ) else: - raise ValueError(f'Policy head type {self.policy_head_type} not supported') + raise ValueError(f"Policy head type {self.policy_head_type} not supported") self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path) @@ -152,6 +153,3 @@ class DexVLAConfig(PreTrainedConfig): @property def reward_delta_indices(self) -> None: return None - - - diff --git a/lerobot/common/policies/dexvla/fusion_modules.py b/lerobot/common/policies/dexvla/fusion_modules.py index 7eb452e0..0d977edc 100644 --- a/lerobot/common/policies/dexvla/fusion_modules.py +++ b/lerobot/common/policies/dexvla/fusion_modules.py @@ -1,16 +1,18 @@ import torch.nn as nn + class ActionProjector(nn.Module): def __init__(self, in_dim, out_dim=1024): - super(ActionProjector, self).__init__() + super().__init__() self.global_1d_pool = nn.AdaptiveAvgPool1d(1) - self.mlps = nn.ModuleList([ - # nn.LayerNorm(in_dim), - nn.Linear(in_dim, in_dim), - nn.GELU(), - nn.Linear(in_dim, out_dim), - nn.Dropout(0.0), - ] + self.mlps = nn.ModuleList( + [ + # nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.Dropout(0.0), + ] ) def forward(self, x): @@ -22,7 +24,7 @@ class ActionProjector(nn.Module): class FiLM(nn.Module): def __init__(self, feature_dim, condition_dim): - super(FiLM, self).__init__() + super().__init__() self.scale_fc = nn.Linear(condition_dim, feature_dim) self.shift_fc = nn.Linear(condition_dim, feature_dim) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index e9330a79..8734751c 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -1,18 +1,16 @@ -import torch -from torch import Tensor - -from 
lerobot.common.policies.normalize import Normalize, Unnormalize -from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig -from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import ( - Qwen2VLForConditionalGenerationForVLA -) -from lerobot.common.policies.pretrained import PreTrainedPolicy from collections import deque -from lerobot.common.policies.dexvla.policy_heads.modeling_unet_diffusion import ConditionalUnet1D -from lerobot.common.policies.dexvla.policy_heads.modeling_scaledp import ScaleDP -from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess -from transformers import AutoProcessor, AutoTokenizer + +import torch import torchvision.transforms as transforms +from torch import Tensor +from transformers import AutoProcessor, AutoTokenizer + +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig +from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA +from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy + class DexVLAPolicy(PreTrainedPolicy): """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" @@ -44,17 +42,17 @@ class DexVLAPolicy(PreTrainedPolicy): config.output_features, config.normalization_mapping, dataset_stats ) - for k in ['using_film', 'llm_loss_weight', 'with_llm_head', 'policy_head_config']: + for k in ["using_film", "llm_loss_weight", "with_llm_head", "policy_head_config"]: setattr(config.qwen2_vla_config, k, config.__dict__[k]) self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16) self.model.requires_grad_(False) self.model.policy_head.requires_grad_(True) self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path) - self.tokenizer = AutoTokenizer.from_pretrained( - config.qwen2_vl_path - ) - self.vla_processor = Qwen2VLAProcess(tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor) # process the input data into VLM format + self.tokenizer = AutoTokenizer.from_pretrained(config.qwen2_vl_path) + self.vla_processor = Qwen2VLAProcess( + tokenizer=self.tokenizer, multimodal_processor=self.qwen2_vl_processor + ) # process the input data into VLM format self.resize_size = self.config.resize_size ratio = 0.95 @@ -73,14 +71,14 @@ class DexVLAPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) present_img_keys = [key for key in self.config.image_features if key in batch] - task_descs = batch['task'] + task_descs = batch["task"] try: - reasonings = batch['reasoning'] + reasonings = batch["reasoning"] except KeyError: - reasonings = ['no reasoning'] * len(task_descs) + reasonings = ["no reasoning"] * len(task_descs) pass - is_pad = batch['action_is_pad'] + is_pad = batch["action_is_pad"] all_cam_images = [] for k in present_img_keys: all_cam_images.append(batch[k]) @@ -89,8 +87,8 @@ class DexVLAPolicy(PreTrainedPolicy): image_data = torch.stack(all_cam_images) * 255 image_data = image_data.to(dtype=torch.uint8) # construct observations - qpos_data = batch['observation.state'].float() - action_data = batch['action'].float() + qpos_data = batch["observation.state"].float() + action_data = batch["action"].float() orig_shape = image_data.shape image_data = image_data.view(-1, *orig_shape[2:]) @@ -100,40 +98,35 @@ 
class DexVLAPolicy(PreTrainedPolicy): image_data = image_data.view(*orig_shape[:3], *self.resize_size) - vl_data = { - 'images': image_data, - 'raw_langs': task_descs, - 'reasonings': reasonings - } + vl_data = {"images": image_data, "raw_langs": task_descs, "reasonings": reasonings} # processing vl_data into qwen2_vl format vla_inputs = self.vla_processor.forward(vl_data, use_reasoning=self.config.using_reasoning) - vla_inputs['states'] = qpos_data - vla_inputs['is_pad'] = is_pad - vla_inputs['actions'] = action_data + vla_inputs["states"] = qpos_data + vla_inputs["is_pad"] = is_pad + vla_inputs["actions"] = action_data return vla_inputs - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: - processed_batch = self.process_batch(batch) ret = self.model.forward(**processed_batch) - loss_dict = ret['loss'] - loss = loss_dict['loss'].mean() + loss_dict = ret["loss"] + loss = loss_dict["loss"].mean() return loss, loss_dict - def dexvla_predict_action(self, - input_ids: torch.LongTensor = None, - actions=None, - states=None, - is_pad=None, - tokenizer=None, - is_eval=True, - pixel_values=None, - attention_mask=None, - image_grid_thw=None, - ): - input_ids = input_ids.to('cuda') + def dexvla_predict_action( + self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + tokenizer=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): + input_ids = input_ids.to("cuda") with torch.inference_mode(): outputs = self.model.generate( input_ids, @@ -157,7 +150,7 @@ class DexVLAPolicy(PreTrainedPolicy): input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids") outputs_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=False)[0] outputs_text = outputs_text.strip() @@ -167,35 +160,44 @@ class DexVLAPolicy(PreTrainedPolicy): action_hidden_states = None if self.model.using_film: - action_hidden_states = self.model.film_forward(labels=torch.ones_like(output_ids), - input_ids=output_ids, - hidden_states=torch.cat(last_hidden_states, dim=1)) + action_hidden_states = self.model.film_forward( + labels=torch.ones_like(output_ids), + input_ids=output_ids, + hidden_states=torch.cat(last_hidden_states, dim=1), + ) - action = self.model.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad) + action = self.model.policy_head( + actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad + ) return action, outputs_text - def tinyvla_predict_action(self, - input_ids: torch.LongTensor = None, - actions=None, - states=None, - is_pad=None, - is_eval=True, - pixel_values=None, - attention_mask=None, - image_grid_thw=None, - ): - input_ids = input_ids.to('cuda') + def tinyvla_predict_action( + self, + input_ids: torch.LongTensor = None, + actions=None, + states=None, + is_pad=None, + is_eval=True, + pixel_values=None, + attention_mask=None, + image_grid_thw=None, + ): + input_ids = input_ids.to("cuda") with torch.inference_mode(): - all_hidden_states = self.model.forward(input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - image_grid_thw=image_grid_thw, - is_eval=is_eval, - tinyvla=True) + all_hidden_states = self.model.forward( + input_ids, + 
pixel_values=pixel_values, + attention_mask=attention_mask, + image_grid_thw=image_grid_thw, + is_eval=is_eval, + tinyvla=True, + ) all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1) - action = self.model.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad) + action = self.model.policy_head( + actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad + ) return action, "tinyvla generates no reasoning" def reset(self): @@ -219,7 +221,7 @@ class DexVLAPolicy(PreTrainedPolicy): if len(self._action_queue) == 0: present_img_keys = [key for key in self.config.image_features if key in batch] try: - task_descs = batch['task'] + task_descs = batch["task"] except KeyError: task_descs = " " print("No task descriptions found for this task") @@ -232,7 +234,7 @@ class DexVLAPolicy(PreTrainedPolicy): image_data = torch.stack(all_cam_images) * 255 image_data = image_data.to(dtype=torch.uint8) # construct observations - qpos_data = batch['observation.state'].float() + qpos_data = batch["observation.state"].float() image_data = image_data.squeeze(0) @@ -240,20 +242,19 @@ class DexVLAPolicy(PreTrainedPolicy): image_data = transform(image_data) # processing vl_data into qwen2_vl format - vla_inputs = self.vla_processor.single_forward_process(images=image_data, raw_lang=task_descs, reasoning=None, eval=True) - vla_inputs['states'] = qpos_data + vla_inputs = self.vla_processor.single_forward_process( + images=image_data, raw_lang=task_descs, reasoning=None, eval=True + ) + vla_inputs["states"] = qpos_data - if self.config.using_film and self.config.with_llm_head: # dexvla - all_actions, outputs = self.dexvla_predict_action(**vla_inputs, is_eval=True, tokenizer=self.tokenizer) - else: # tinyvla + if self.config.using_film and self.config.with_llm_head: # dexvla + all_actions, outputs = self.dexvla_predict_action( + **vla_inputs, is_eval=True, tokenizer=self.tokenizer + ) + else: # tinyvla all_actions, outputs = self.tinyvla_predict_action(**vla_inputs, is_eval=True) actions = self.unnormalize_outputs({"action": all_actions})["action"] self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - - - - - diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py index 0837f499..6a8f7ea9 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py @@ -1,47 +1,58 @@ import os -from typing import Union, List -from transformers import PretrainedConfig +from typing import Union +from transformers import AutoConfig, PretrainedConfig from transformers.utils import logging -from transformers import AutoConfig, AutoModelForCausalLM + logger = logging.get_logger(__name__) MODEL_STRUCTURE = { - 'ScaleDP_H': {'depth': 32, 'n_emb': 1280, 'num_heads': 16, }, - 'ScaleDP_L': {'depth': 24, 'n_emb': 1024, 'num_heads': 16, }, # 400M + "ScaleDP_H": { + "depth": 32, + "n_emb": 1280, + "num_heads": 16, + }, + "ScaleDP_L": { + "depth": 24, + "n_emb": 1024, + "num_heads": 16, + }, # 400M } + class ScaleDPPolicyConfig(PretrainedConfig): - ''' + """ Configuration for ScaleDP policy head - ''' + """ + model_type = "scale_dp_policy" + def __init__( - self, - eval: bool = False, - action_dim: int = 14, # action dim - # output_dim: int = 14, # action dim - cond_dim: int = 1536, # the input dim of the condition - state_dim: int = 14, # the input dim of the state - 
prediction_horizon: int = 16, # horizon - n_obs_steps: int = 2, # number of observation steps - depth: int = 28, # number of DiT blocks - n_emb: int = 256, # embedding size - num_heads: int = 16, - mlp_ratio: int = 4.0, - time_as_cond: bool = True, - obs_as_cond: bool = True, - learn_sigma: bool = False, - model_size: str = "none", - num_inference_timesteps: int = 10, - noise_samples: int = 1, - num_train_timesteps: int = 100, - **kwargs + self, + eval: bool = False, + action_dim: int = 14, # action dim + # output_dim: int = 14, # action dim + cond_dim: int = 1536, # the input dim of the condition + state_dim: int = 14, # the input dim of the state + prediction_horizon: int = 16, # horizon + n_obs_steps: int = 2, # number of observation steps + depth: int = 28, # number of DiT blocks + n_emb: int = 256, # embedding size + num_heads: int = 16, + mlp_ratio: int = 4.0, + time_as_cond: bool = True, + obs_as_cond: bool = True, + learn_sigma: bool = False, + model_size: str = "none", + num_inference_timesteps: int = 10, + noise_samples: int = 1, + num_train_timesteps: int = 100, + **kwargs, ): if model_size != "none": - depth = MODEL_STRUCTURE[model_size]['depth'] - n_emb = MODEL_STRUCTURE[model_size]['n_emb'] - num_heads = MODEL_STRUCTURE[model_size]['num_heads'] + depth = MODEL_STRUCTURE[model_size]["depth"] + n_emb = MODEL_STRUCTURE[model_size]["n_emb"] + num_heads = MODEL_STRUCTURE[model_size]["num_heads"] else: # raise ValueError("model_size show not be 'none'") pass @@ -52,7 +63,6 @@ class ScaleDPPolicyConfig(PretrainedConfig): self.output_dim = action_dim self.prediction_horizon = prediction_horizon - self.cond_dim = cond_dim self.state_dim = state_dim @@ -72,7 +82,9 @@ class ScaleDPPolicyConfig(PretrainedConfig): super().__init__(**kwargs) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) @@ -81,7 +93,11 @@ class ScaleDPPolicyConfig(PretrainedConfig): if config_dict.get("model_type") == "llava_pythia": config_dict = config_dict["action_head"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
@@ -89,4 +105,5 @@ class ScaleDPPolicyConfig(PretrainedConfig): return cls.from_dict(config_dict, **kwargs) + AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py index 38e403a6..aaf66447 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py @@ -1,31 +1,33 @@ import os -from typing import Union, List -from transformers import PretrainedConfig +from typing import Union +from transformers import AutoConfig, PretrainedConfig from transformers.utils import logging -from transformers import AutoConfig, AutoModelForCausalLM + logger = logging.get_logger(__name__) + class UnetDiffusionPolicyConfig(PretrainedConfig): - ''' + """ Configuration for dit diffusion policy head - ''' + """ + model_type = "unet_diffusion_policy" def __init__( - self, - action_dim=10, - global_cond_dim=2048, - diffusion_step_embed_dim=256, - down_dims=[256, 512, 1024], - kernel_size=5, - n_groups=8, - state_dim=7, - prediction_horizon=16, - noise_samples=1, - num_inference_timesteps=10, - num_train_timesteps=100, - **kwargs + self, + action_dim=10, + global_cond_dim=2048, + diffusion_step_embed_dim=256, + down_dims=[256, 512, 1024], + kernel_size=5, + n_groups=8, + state_dim=7, + prediction_horizon=16, + noise_samples=1, + num_inference_timesteps=10, + num_train_timesteps=100, + **kwargs, ): self.input_dim = action_dim self.noise_samples = noise_samples @@ -42,7 +44,9 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): super().__init__(**kwargs) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) @@ -51,7 +55,11 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): if config_dict.get("model_type") == "llava_pythia": config_dict = config_dict["action_head"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." @@ -59,4 +67,5 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): return cls.from_dict(config_dict, **kwargs) + AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index 4c78b6e1..b9a9b919 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -1,27 +1,24 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
-from typing import Tuple - -import timm -import numpy as np import logging - import math from typing import Tuple +import numpy as np + try: from typing import Literal except ImportError: - from typing_extensions import Literal + pass import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from torch.jit import Final from timm.models.vision_transformer import Mlp, use_fused_attn +from torch.jit import Final +from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel -from transformers import AutoModel, AutoModelForCausalLM _logger = logging.getLogger(__name__) @@ -30,20 +27,20 @@ class Attention(nn.Module): fused_attn: Final[bool] def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = False, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = nn.LayerNorm, + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, ) -> None: super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' + assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) @@ -61,8 +58,11 @@ class Attention(nn.Module): if self.fused_attn: x = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, - dropout_p=self.attn_drop.p if self.training else 0., + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0.0, ) else: q = q * self.scale @@ -104,6 +104,7 @@ def modulate(x, shift, scale): # Embedding Layers for Timesteps and Class Labels # ################################################################################# + class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. @@ -145,11 +146,11 @@ class TimestepEmbedder(nn.Module): return t_emb - ################################################################################# # Core ScaleDP Model # ################################################################################# + class ScaleDPBlock(nn.Module): """ A ScaleDP block with adaptive layer norm zero (adaLN-Zero) conScaleDPioning. 
@@ -163,14 +164,15 @@ class ScaleDPBlock(nn.Module): mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) def forward(self, x, c, attn_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) - x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask) # norm, scale&shift, attn, scale, + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk( + 6, dim=1 + ) + x = x + gate_msa.unsqueeze(1) * self.attn( + modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask + ) # norm, scale&shift, attn, scale, x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x @@ -184,10 +186,7 @@ class FinalLayer(nn.Module): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, output_dim, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True) - ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) @@ -195,15 +194,20 @@ class FinalLayer(nn.Module): x = self.linear(x) return x + from .configuration_scaledp import ScaleDPPolicyConfig + + class ScaleDP(PreTrainedModel): """ Diffusion models with a Transformer backbone. 
""" + config_class = ScaleDPPolicyConfig + def __init__( - self, - config: ScaleDPPolicyConfig, + self, + config: ScaleDPPolicyConfig, ): super().__init__(config) # compute number of tokens for main trunk and conScaleDPion encoder @@ -221,11 +225,11 @@ class ScaleDP(PreTrainedModel): # self.combine = nn.Linear(cond_dim+state_dim, cond_dim) self.combine = nn.Sequential( - nn.Linear(config.cond_dim+config.state_dim, 1024), + nn.Linear(config.cond_dim + config.state_dim, 1024), nn.ReLU(), nn.Linear(1024, 1024), nn.ReLU(), - nn.Linear(1024, config.cond_dim) + nn.Linear(1024, config.cond_dim), ) self.learn_sigma = config.learn_sigma self.input_dim = config.input_dim @@ -241,9 +245,12 @@ class ScaleDP(PreTrainedModel): # Will use fixed sin-cos embedding: self.pos_embed = nn.Parameter(torch.zeros(1, config.prediction_horizon, config.n_emb)) - self.blocks = nn.ModuleList([ - ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) for _ in range(config.depth) - ]) + self.blocks = nn.ModuleList( + [ + ScaleDPBlock(config.n_emb, config.num_heads, mlp_ratio=config.mlp_ratio) + for _ in range(config.depth) + ] + ) self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim) # self.initialize_weights() # constants @@ -253,23 +260,22 @@ class ScaleDP(PreTrainedModel): self.time_as_cond = config.time_as_cond self.action_dim = config.output_dim self.obs_as_cond = obs_as_cond - logger.info( - "number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters()) - ) + logger.info("number of parameters in ScaleDP: %e", sum(p.numel() for p in self.parameters())) from diffusers.schedulers.scheduling_ddim import DDIMScheduler + self.num_inference_timesteps = config.num_inference_timesteps # self.proj_to_action = nn.Identity() self.noise_scheduler = DDIMScheduler( - num_train_timesteps=config.num_train_timesteps, # 100 - beta_schedule='squaredcos_cap_v2', + num_train_timesteps=config.num_train_timesteps, # 100 + beta_schedule="squaredcos_cap_v2", clip_sample=True, set_alpha_to_one=True, steps_offset=0, - prediction_type='epsilon' + prediction_type="epsilon", ) - self.num_queries = config.num_queries #16 - self.noise_samples = config.noise_samples # 1 + self.num_queries = config.num_queries # 16 + self.noise_samples = config.noise_samples # 1 # self.num_inference_timesteps = config.num_inference_timesteps # 100 def initialize_weights(self): @@ -308,7 +314,6 @@ class ScaleDP(PreTrainedModel): nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) - def get_optim_groups(self, weight_decay: float = 1e-3): """ This long function is unfortunately doing something very simple and is being very defensive: @@ -324,7 +329,7 @@ class ScaleDP(PreTrainedModel): blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): - fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + fpn = "{}.{}".format(mn, pn) if mn else pn # full param name if pn.endswith("bias"): # all biases will not be decayed @@ -343,13 +348,13 @@ class ScaleDP(PreTrainedModel): param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) - assert ( - len(param_dict.keys() - union_params) == 0 - ), "parameters %s were not separated into either decay/no_decay set!" 
% ( - str(param_dict.keys() - union_params), + assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters {} were not separated into either decay/no_decay set!".format( + str(param_dict.keys() - union_params), + ) ) # create the pytorch optimizer object @@ -365,14 +370,14 @@ class ScaleDP(PreTrainedModel): ] return optim_groups - def configure_optimizers(self, - learning_rate: float = 1e-4, - weight_decay: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.95)): + def configure_optimizers( + self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95), + ): optim_groups = self.get_optim_groups(weight_decay=weight_decay) - optimizer = torch.optim.AdamW( - optim_groups, lr=learning_rate, betas=betas - ) + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) return optimizer def forward(self, actions, hidden_states, states, is_pad): @@ -385,25 +390,26 @@ class ScaleDP(PreTrainedModel): """ if actions is not None: # training time B = actions.size(0) - actions = actions[:, :self.num_queries] - is_pad = is_pad[:, :self.num_queries] + actions = actions[:, : self.num_queries] + is_pad = is_pad[:, : self.num_queries] num_noise_samples = self.noise_samples # sample noise to add to actions - noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device, - dtype=actions.dtype) # num_noise, B, Ta, D(1, 2, 16, 14) + noise = torch.randn( + [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype + ) # num_noise, B, Ta, D(1, 2, 16, 14) # sample a diffusion iteration for each data point timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, - (B,), device=actions.device + 0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device ).long() timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) # add noise to the clean actions according to the noise magnitude at each diffusion iteration # (this is the forward diffusion process) - noisy_actions = torch.cat([self.noise_scheduler.add_noise( - actions, noise[i], timesteps) - for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim] + noisy_actions = torch.cat( + [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], + dim=0, + ) # [num_noise_samples * B, Ta, action_dim] noisy_actions = noisy_actions.to(dtype=actions.dtype) assert hidden_states.ndim == 3 @@ -411,14 +417,16 @@ class ScaleDP(PreTrainedModel): hidden_states = hidden_states.repeat(num_noise_samples, 1, 1) timesteps = timesteps.repeat(num_noise_samples) is_pad = is_pad.repeat(num_noise_samples, 1) - states = states.repeat(num_noise_samples, 1) + states = states.repeat(num_noise_samples, 1) - noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states) + noise_pred = self.model_forward( + noisy_actions, timesteps, global_cond=hidden_states, states=states + ) noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) - loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none') + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none") loss = (loss * ~is_pad.unsqueeze(-1)).mean() # loss_dict['loss'] = loss - return {'loss': loss} + return {"loss": loss} # return loss else: # inference time B = 1 @@ -438,9 +446,7 @@ class ScaleDP(PreTrainedModel): # inverse 
diffusion step (remove noise) naction = self.noise_scheduler.step( - model_output=noise_pred, - timestep=k, - sample=naction + model_output=noise_pred, timestep=k, sample=naction ).prev_sample return naction @@ -462,7 +468,9 @@ class ScaleDP(PreTrainedModel): t = t[None].to(x.device) t = t.expand(t.shape[0]) - x = self.x_embedder(x) + self.pos_embed.to(device=x.device, dtype=x.dtype) # (N, T, D), where T = prediction_horizon + x = self.x_embedder(x) + self.pos_embed.to( + device=x.device, dtype=x.dtype + ) # (N, T, D), where T = prediction_horizon t = self.t_embedder(t) # (N, D) if self.obs_as_cond: global_cond = self.cond_obs_emb(global_cond) # (N, D) @@ -474,11 +482,13 @@ class ScaleDP(PreTrainedModel): x = self.final_layer(x, c) # (N, T, output_dim) return x + ################################################################################# # Sine/Cosine Positional Embedding Functions # ################################################################################# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ grid_size: int of the grid height and width @@ -516,11 +526,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2. - omega = 1. / 10000 ** omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) @@ -533,12 +543,13 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): # ScaleDP Configs # ################################################################################# + def ScaleDP_H(**kwargs): return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs) + def ScaleDP_L(**kwargs): return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs) - AutoModel.register(ScaleDPPolicyConfig, ScaleDP) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py index a7b456d2..eba83e36 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -1,29 +1,29 @@ """ Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi """ -from typing import Callable, Union + +import copy import math -from collections import OrderedDict, deque -from packaging.version import parse as parse_version -import random +from typing import Union + import torch import torch.nn as nn -import torch.nn.functional as F + # requires diffusers==0.11.1 -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from diffusers.training_utils import EMAModel -from .configuration_unet_diffusion import UnetDiffusionPolicyConfig +from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel -from transformers import AutoModel, AutoModelForCausalLM -import copy + +from .configuration_unet_diffusion import UnetDiffusionPolicyConfig + # =================== UNet for Diffusion ============== + class SinusoidalPosEmb(nn.Module): def __init__(self, dim, dtype): super().__init__() self.dim = dim - self.dtype=dtype + self.dtype = dtype 
def forward(self, x): device = x.device @@ -54,9 +54,9 @@ class Upsample1d(nn.Module): class Conv1dBlock(nn.Module): - ''' - Conv1d --> GroupNorm --> Mish - ''' + """ + Conv1d --> GroupNorm --> Mish + """ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() @@ -72,46 +72,41 @@ class Conv1dBlock(nn.Module): class ConditionalResidualBlock1D(nn.Module): - def __init__(self, - in_channels, - out_channels, - cond_dim, - kernel_size=3, - n_groups=8): + def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8): super().__init__() - self.blocks = nn.ModuleList([ - Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), - Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), - ]) + self.blocks = nn.ModuleList( + [ + Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), + Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), + ] + ) # FiLM modulation https://arxiv.org/abs/1709.07871 # predicts per-channel scale and bias cond_channels = out_channels * 2 self.out_channels = out_channels self.cond_encoder = nn.Sequential( - nn.Mish(), - nn.Linear(cond_dim, cond_channels), - nn.Unflatten(-1, (-1, 1)) + nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1)) ) # make sure dimensions compatible - self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ - if in_channels != out_channels else nn.Identity() + self.residual_conv = ( + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + ) def forward(self, x, cond): - ''' - x : [ batch_size x in_channels x horizon ] - cond : [ batch_size x cond_dim] + """ + x : [ batch_size x in_channels x horizon ] + cond : [ batch_size x cond_dim] - returns: - out : [ batch_size x out_channels x horizon ] - ''' + returns: + out : [ batch_size x out_channels x horizon ] + """ out = self.blocks[0](x) embed = self.cond_encoder(cond) - embed = embed.reshape( - embed.shape[0], 2, self.out_channels, 1) + embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) scale = embed[:, 0, ...] bias = embed[:, 1, ...] out = scale * out + bias @@ -125,9 +120,8 @@ class ConditionalUnet1D(PreTrainedModel): _no_split_modules = ["mid_modules", "down_modules", "up_modules"] config_class = UnetDiffusionPolicyConfig - def __init__(self, - config: UnetDiffusionPolicyConfig - ): + + def __init__(self, config: UnetDiffusionPolicyConfig): """ input_dim: Dim of actions. 
global_cond_dim: Dim of global conditioning applied with FiLM @@ -148,7 +142,7 @@ class ConditionalUnet1D(PreTrainedModel): # self.global_1d_pool = nn.AdaptiveAvgPool1d(1) # self.proj2action = nn.Linear(config.hidden_dim, config.global_cond_dim) self.norm_after_pool = nn.LayerNorm(config.global_cond_dim) - self.combine = nn.Linear(config.global_cond_dim+config.state_dim, config.global_cond_dim) + self.combine = nn.Linear(config.global_cond_dim + config.state_dim, config.global_cond_dim) dsed = config.diffusion_step_embed_dim diffusion_step_encoder = nn.Sequential( SinusoidalPosEmb(dsed, torch.bfloat16), @@ -158,44 +152,76 @@ class ConditionalUnet1D(PreTrainedModel): ) cond_dim = dsed + config.global_cond_dim - in_out = list(zip(all_dims[:-1], all_dims[1:])) + in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False)) mid_dim = all_dims[-1] - self.mid_modules = nn.ModuleList([ - ConditionalResidualBlock1D( - mid_dim, mid_dim, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups - ), - ConditionalResidualBlock1D( - mid_dim, mid_dim, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups - ), - ]) + self.mid_modules = nn.ModuleList( + [ + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ] + ) down_modules = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (len(in_out) - 1) - down_modules.append(nn.ModuleList([ - ConditionalResidualBlock1D( - dim_in, dim_out, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - ConditionalResidualBlock1D( - dim_out, dim_out, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - Downsample1d(dim_out) if not is_last else nn.Identity() - ])) + down_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + dim_out, + dim_out, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + Downsample1d(dim_out) if not is_last else nn.Identity(), + ] + ) + ) up_modules = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (len(in_out) - 1) - up_modules.append(nn.ModuleList([ - ConditionalResidualBlock1D( - dim_out * 2, dim_in, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - ConditionalResidualBlock1D( - dim_in, dim_in, cond_dim=cond_dim, - kernel_size=config.kernel_size, n_groups=config.n_groups), - Upsample1d(dim_in) if not is_last else nn.Identity() - ])) + up_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_out * 2, + dim_in, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + ConditionalResidualBlock1D( + dim_in, + dim_in, + cond_dim=cond_dim, + kernel_size=config.kernel_size, + n_groups=config.n_groups, + ), + Upsample1d(dim_in) if not is_last else nn.Identity(), + ] + ) + ) final_conv = nn.Sequential( Conv1dBlock(start_dim, start_dim, kernel_size=config.kernel_size), @@ -207,20 +233,17 @@ class ConditionalUnet1D(PreTrainedModel): self.down_modules = down_modules self.final_conv = final_conv - print("number of parameters: {:e}".format( - sum(p.numel() for p in self.parameters())) - 
) + print("number of parameters: {:e}".format(sum(p.numel() for p in self.parameters()))) - from diffusers.schedulers.scheduling_ddim import DDIMScheduler self.num_inference_timesteps = config.num_inference_timesteps # self.proj_to_action = nn.Identity() self.noise_scheduler = DDIMScheduler( num_train_timesteps=config.num_train_timesteps, # 100 - beta_schedule='squaredcos_cap_v2', + beta_schedule="squaredcos_cap_v2", clip_sample=True, set_alpha_to_one=True, steps_offset=0, - prediction_type='epsilon' + prediction_type="epsilon", ) # self.num_inference_timesteps = config.num_inference_timesteps # 100 @@ -235,25 +258,26 @@ class ConditionalUnet1D(PreTrainedModel): """ if actions is not None: # training time B = actions.size(0) - actions = copy.deepcopy(actions[:, :self.num_queries]) - is_pad = copy.deepcopy(is_pad[:, :self.num_queries]) + actions = copy.deepcopy(actions[:, : self.num_queries]) + is_pad = copy.deepcopy(is_pad[:, : self.num_queries]) num_noise_samples = self.noise_samples # sample noise to add to actions - noise = torch.randn([num_noise_samples] + list(actions.shape), device=actions.device, - dtype=actions.dtype) # num_noise, B, Ta, D + noise = torch.randn( + [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype + ) # num_noise, B, Ta, D # sample a diffusion iteration for each data point timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, - (B,), device=actions.device + 0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device ).long() timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) # add noise to the clean actions according to the noise magnitude at each diffusion iteration # (this is the forward diffusion process) - noisy_actions = torch.cat([self.noise_scheduler.add_noise( - actions, noise[i], timesteps) - for i in range(len(noise))], dim=0) # [num_noise_samples * B, Ta, action_dim] + noisy_actions = torch.cat( + [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], + dim=0, + ) # [num_noise_samples * B, Ta, action_dim] noisy_actions = noisy_actions.to(dtype=actions.dtype) assert hidden_states.ndim == 3 @@ -263,12 +287,14 @@ class ConditionalUnet1D(PreTrainedModel): is_pad = is_pad.repeat(num_noise_samples, 1) states = states.repeat(num_noise_samples, 1) - noise_pred = self.model_forward(noisy_actions, timesteps, global_cond=hidden_states, states=states) + noise_pred = self.model_forward( + noisy_actions, timesteps, global_cond=hidden_states, states=states + ) noise = noise.view(noise.size(0) * noise.size(1), *noise.size()[2:]) - loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction='none') + loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none") loss = (loss * ~is_pad.unsqueeze(-1)).mean() # loss_dict['loss'] = loss - return {'loss': loss} + return {"loss": loss} # return loss else: # inference time B = 1 @@ -288,18 +314,14 @@ class ConditionalUnet1D(PreTrainedModel): # inverse diffusion step (remove noise) naction = self.noise_scheduler.step( - model_output=noise_pred, - timestep=k, - sample=naction + model_output=noise_pred, timestep=k, sample=naction ).prev_sample return naction - def model_forward(self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - global_cond=None, - states=None): + def model_forward( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None + ): """ x: (B,T,input_dim) timestep: (B,) or int, diffusion step @@ 
-327,9 +349,7 @@ class ConditionalUnet1D(PreTrainedModel): global_feature = self.diffusion_step_encoder(timesteps) if global_cond is not None: - global_feature = torch.cat([ - global_feature, global_cond - ], axis=-1) + global_feature = torch.cat([global_feature, global_cond], axis=-1) x = sample h = [] @@ -355,4 +375,5 @@ class ConditionalUnet1D(PreTrainedModel): # (B,T,C) return x + AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py index a1a1d81f..f6b46350 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,10 +16,10 @@ import os from typing import Union +from transformers import AutoConfig from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging -from transformers import AutoModel, AutoConfig logger = logging.get_logger(__name__) @@ -56,7 +55,9 @@ class Qwen2VLVisionConfig(PretrainedConfig): self.temporal_patch_size = temporal_patch_size @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) @@ -64,7 +65,11 @@ class Qwen2VLVisionConfig(PretrainedConfig): if config_dict.get("model_type") == "qwen2_vl": config_dict = config_dict["vision_config"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
@@ -204,7 +209,7 @@ class Qwen2VLAConfig(PretrainedConfig): vision_config=None, rope_scaling=None, # For loading policy head - policy_head_type='scale_dp_policy', # unet_diffusion_policy + policy_head_type="scale_dp_policy", # unet_diffusion_policy **kwargs, ): if isinstance(vision_config, dict): @@ -221,7 +226,7 @@ class Qwen2VLAConfig(PretrainedConfig): self.use_sliding_window = use_sliding_window self.sliding_window = sliding_window self.max_window_layers = max_window_layers - self.policy_head_type = policy_head_type # for loading policy head + self.policy_head_type = policy_head_type # for loading policy head # for backward compatibility if num_key_value_heads is None: @@ -248,5 +253,5 @@ class Qwen2VLAConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -from transformers import AutoConfig + AutoConfig.register("qwen2_vla", Qwen2VLAConfig) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index e37fea19..235c66a3 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -19,6 +18,7 @@ # limitations under the License. """PyTorch Qwen2-VL model.""" +import gc import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -28,7 +28,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import CrossEntropyLoss, LayerNorm - +from transformers import AutoConfig, AutoModel from transformers.activations import ACT2FN from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin @@ -37,8 +37,6 @@ from transformers.modeling_outputs import ( BaseModelOutputWithPast, ModelOutput, ) -from lerobot.common.policies.dexvla.fusion_modules import * - from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( @@ -49,14 +47,13 @@ from transformers.utils import ( logging, replace_return_docstrings, ) -from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig -from transformers import AutoConfig, AutoModel -import gc +from lerobot.common.policies.dexvla.fusion_modules import * + +from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func - from transformers.modeling_flash_attention_utils import _flash_attention_forward else: flash_attn_varlen_func = None @@ -161,10 +158,12 @@ class Qwen2VLRotaryEmbedding(nn.Module): inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len, **self.rope_kwargs ) - self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("inv_freq", inv_freq, persistent=False) self.max_seq_len_cached = seq_len - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + if ( + seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len + ): # reset self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -335,7 +334,9 
@@ class VisionAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -369,7 +370,9 @@ class VisionFlashAttention2(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -392,7 +395,9 @@ class VisionSdpaAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None ) -> torch.Tensor: seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -538,7 +543,9 @@ class Qwen2VLAttention(nn.Module): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -569,8 +576,14 @@ class Qwen2VLAttention(nn.Module): ) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -585,7 +598,9 @@ class Qwen2VLAttention(nn.Module): # Fix precision issues in Qwen2-VL float16 inference # Replace inf values with zeros in attention weights to prevent NaN propagation if query_states.dtype == torch.float16: - attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + attn_weights = torch.where( + torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights + ) # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -634,7 +649,9 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention): output_attentions: bool = False, use_cache: 
bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 ): bsz, q_len, _ = hidden_states.size() @@ -696,10 +713,18 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention): if attention_mask is not None: attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + attention_mask = torch.cat( + [attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1 + ) - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -781,7 +806,9 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: logger.warning_once( @@ -826,8 +853,14 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): ) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -897,7 +930,9 @@ class Qwen2VLDecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1116,7 +1151,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1208,7 +1245,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): next_cache = next_decoder_cache if 
use_cache else None if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1436,6 +1475,7 @@ QWEN2_VL_INPUTS_DOCSTRING = r""" The rope index difference between sequence length and multimodal rope. """ + class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -1599,9 +1639,15 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + t_index = ( + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + ) + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + ) llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w @@ -1721,18 +1767,20 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi ```""" self.computed_type = torch.bfloat16 - input_ids=input_ids.to("cuda") - attention_mask=attention_mask.to("cuda") + input_ids = input_ids.to("cuda") + attention_mask = attention_mask.to("cuda") if not is_eval: labels = labels.to("cuda") - actions = actions.to(dtype=self.computed_type, device='cuda') - states = states.to(dtype=self.computed_type, device='cuda') + actions = actions.to(dtype=self.computed_type, device="cuda") + states = states.to(dtype=self.computed_type, device="cuda") position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, video_grid_thw, attention_mask ) if pixel_values is not None: - pixel_values = pixel_values.to(dtype=self.computed_type, device='cuda') - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda") + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1792,7 +1840,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi ) hidden_states = outputs[0] - if tinyvla: # dex-vla supports tinyvla-style VLA + if tinyvla: # dex-vla supports tinyvla-style VLA return hidden_states if self.with_llm_head: @@ -1831,23 +1879,30 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi attentions=outputs.attentions, rope_deltas=rope_deltas, ) - + if self.using_film: - action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids, - hidden_states=hidden_states) - else: # tinyvla + action_hidden_states = self.film_forward( + labels=labels, input_ids=input_ids, hidden_states=hidden_states + ) + else: 
# tinyvla action_hidden_states = hidden_states - ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad) + ret = self.policy_head( + actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad + ) if self.with_llm_head: - loss = {'loss': ret['loss'] + self.llm_loss_weight * llm_loss, - 'llm_loss': llm_loss, - 'action_loss': ret['loss']} + loss = { + "loss": ret["loss"] + self.llm_loss_weight * llm_loss, + "llm_loss": llm_loss, + "action_loss": ret["loss"], + } else: - loss = {'loss': ret['loss'], - 'llm_loss': (torch.ones(1)*(-100)).to(ret['loss'].dtype).squeeze(0), - 'action_loss': ret['loss']} + loss = { + "loss": ret["loss"], + "llm_loss": (torch.ones(1) * (-100)).to(ret["loss"].dtype).squeeze(0), + "action_loss": ret["loss"], + } if not return_dict: output = (logits,) + outputs[1:] @@ -1904,30 +1959,32 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return action_hidden_states def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - pixel_values=None, - pixel_values_videos=None, - image_grid_thw=None, - video_grid_thw=None, - **kwargs, + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0]:] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - rope_deltas = kwargs.get("rope_deltas", None) + rope_deltas = kwargs.get("rope_deltas") if attention_mask is not None and position_ids is None: if cache_position is None or (cache_position is not None and cache_position[0] == 0): position_ids, rope_deltas = self.get_rope_index( @@ -1936,7 +1993,9 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi else: batch_size, seq_length = input_ids.shape delta = ( - cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + cache_position[0] + rope_deltas + if cache_position is not None and rope_deltas is not None + else 0 ) position_ids = torch.arange(seq_length, device=input_ids.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) @@ -1990,6 +2049,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return model_inputs - from transformers import AutoModelForCausalLM + AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py index 85e43572..08ede3e0 100644 --- a/lerobot/common/policies/dexvla/robot_data_processor.py +++ 
b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -1,16 +1,15 @@ -from PIL import Image import numpy as np -from torchvision.transforms.functional import to_pil_image, to_tensor -import torchvision.transforms as transforms import torch -from qwen_vl_utils import process_vision_info +from PIL import Image from qwen_vl_utils import fetch_image + + class Qwen2VLAProcess: def __init__( - self, - tokenizer=None, - max_seq_len=512, - multimodal_processor=None, + self, + tokenizer=None, + max_seq_len=512, + multimodal_processor=None, ): super().__init__() self.tokenizer = tokenizer @@ -20,10 +19,10 @@ class Qwen2VLAProcess: def qwen2_image_preprocess(self, each): ele = {} each = Image.fromarray(each.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)) - ele['image'] = each + ele["image"] = each - ele['resized_height'] = each.height - ele['resized_width'] = each.width + ele["resized_height"] = each.height + ele["resized_width"] = each.width each = fetch_image(ele) return torch.from_numpy(np.array(each)) @@ -58,61 +57,63 @@ class Qwen2VLAProcess: if eval: return model_inputs - input_labels = torch.ones_like(model_inputs['input_ids']) * -100 + input_labels = torch.ones_like(model_inputs["input_ids"]) * -100 if use_reasoning: - answer =reasoning + "Next action:" + '<|im_end|>' + answer = reasoning + "Next action:" + "<|im_end|>" else: - answer = '' + '<|im_end|>' + answer = "" + "<|im_end|>" output_text = self.tokenizer(answer, padding=True, return_tensors="pt") - output_labels = output_text['input_ids'] - model_inputs['input_ids'] = torch.cat((model_inputs['input_ids'], output_text['input_ids']), dim=-1) - model_inputs['attention_mask'] = torch.cat((model_inputs['attention_mask'], output_text['attention_mask']), dim=-1) + output_labels = output_text["input_ids"] + model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], output_text["input_ids"]), dim=-1) + model_inputs["attention_mask"] = torch.cat( + (model_inputs["attention_mask"], output_text["attention_mask"]), dim=-1 + ) labels = torch.cat((input_labels, output_labels), dim=-1) - data_dict['labels'] = labels + data_dict["labels"] = labels for k, v in model_inputs.items(): data_dict[k] = v return data_dict def forward(self, batch, use_reasoning=True): """This is the main process function for processing vl data into Qwen2_vl format""" - all_images = batch['images'] - all_images = torch.einsum('v b c h w -> b v c h w', all_images) # camera_views, batch_size, channel, height, width + all_images = batch["images"] + all_images = torch.einsum( + "v b c h w -> b v c h w", all_images + ) # camera_views, batch_size, channel, height, width ret_l = [] for idx, images in enumerate(all_images): - raw_lang = batch['raw_langs'][idx] - reasoning = batch['reasonings'][idx] + raw_lang = batch["raw_langs"][idx] + reasoning = batch["reasonings"][idx] ret_dict = self.single_forward_process(images, raw_lang, reasoning, use_reasoning=use_reasoning) ret_l.append(ret_dict) return self.post_process(ret_l) def post_process(self, instances): - input_ids = [torch.flip(instance['input_ids'].squeeze(0), dims=[0]) for instance in instances] - labels = [torch.flip(instance['labels'].squeeze(0), dims=[0]) for instance in instances] + input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances] + labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances] - image_grid_thw = torch.stack([instances['image_grid_thw'] for instances in instances]) - pixel_values = 
torch.stack([instances['pixel_values'] for instances in instances]) + image_grid_thw = torch.stack([instances["image_grid_thw"] for instances in instances]) + pixel_values = torch.stack([instances["pixel_values"] for instances in instances]) pixel_values_videos = None video_grid_thw = None - labels = torch.nn.utils.rnn.pad_sequence(labels, - batch_first=True, - padding_value=-100) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) labels = torch.flip(labels, dims=[1]) - input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) input_ids = torch.flip(input_ids, dims=[1]) b = input_ids.shape[0] image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2]) pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2]) - attention_mask = input_ids.ne(self.tokenizer.pad_token_id), + attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),) batch = dict( input_ids=input_ids, @@ -126,7 +127,6 @@ class Qwen2VLAProcess: return batch def construct_chat_data(self, len_image, raw_lang): - messages = [ { "role": "user", @@ -135,11 +135,13 @@ class Qwen2VLAProcess: ] for i in range(len_image): - messages[0]['content'].append({ - "type": "image", - "image": None, - }) - messages[0]['content'].append({"type": "text", "text": f""}) - messages[0]['content'][-1]['text'] = raw_lang + messages[0]["content"].append( + { + "type": "image", + "image": None, + } + ) + messages[0]["content"].append({"type": "text", "text": ""}) + messages[0]["content"][-1]["text"] = raw_lang - return messages \ No newline at end of file + return messages diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index e7777367..299877a6 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -24,9 +24,9 @@ from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.envs.configs import EnvConfig from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig +from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.pi0.configuration_pi0 import PI0Config -from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig From 9f4d4904234c52715cc0fc12e08ad3da7cbc43d5 Mon Sep 17 00:00:00 2001 From: wk Date: Tue, 11 Mar 2025 13:18:08 +0800 Subject: [PATCH 12/36] cyf_update --- .../common/policies/dexvla/modeling_dexvla.py | 8 +-- .../dexvla/policy_heads/modeling_scaledp.py | 2 +- .../qwe2_vla/configuration_qwen2_vla.py | 2 +- .../dexvla/qwe2_vla/modeling_qwen2_vla.py | 70 +++++++++---------- .../policies/dexvla/robot_data_processor.py | 10 +-- 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index 8734751c..0184ccb7 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ 
-124,7 +124,7 @@ class DexVLAPolicy(PreTrainedPolicy): is_eval=True, pixel_values=None, attention_mask=None, - image_grid_thw=None, + image_grid_spatiotemporal=None, ): input_ids = input_ids.to("cuda") with torch.inference_mode(): @@ -132,7 +132,7 @@ class DexVLAPolicy(PreTrainedPolicy): input_ids, pixel_values=pixel_values, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, + image_grid_spatiotemporal=image_grid_spatiotemporal, is_eval=is_eval, num_beams=1, do_sample=False, @@ -180,7 +180,7 @@ class DexVLAPolicy(PreTrainedPolicy): is_eval=True, pixel_values=None, attention_mask=None, - image_grid_thw=None, + image_grid_spatiotemporal=None, ): input_ids = input_ids.to("cuda") with torch.inference_mode(): @@ -188,7 +188,7 @@ class DexVLAPolicy(PreTrainedPolicy): input_ids, pixel_values=pixel_values, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, + image_grid_spatiotemporal=image_grid_spatiotemporal, is_eval=is_eval, tinyvla=True, ) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index b9a9b919..0614a3b6 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -433,7 +433,7 @@ class ScaleDP(PreTrainedModel): Tp = self.num_queries action_dim = self.action_dim - # initialize action from Guassian noise + # initialize action from Gaussian noise noisy_action = torch.randn((B, Tp, action_dim)).cuda() naction = noisy_action.to(dtype=hidden_states.dtype) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py index f6b46350..80717bc2 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py @@ -243,7 +243,7 @@ class Qwen2VLAConfig(PretrainedConfig): # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, move it to 'rope_type'. - # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations # one can set it to "linear"/"dynamic" etc. to have scaled RoPE if self.rope_scaling is not None and "type" in self.rope_scaling: if self.rope_scaling["type"] == "mrope": diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index 235c66a3..341887ff 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -172,7 +172,7 @@ class Qwen2VLRotaryEmbedding(nn.Module): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) - # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids + # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for spatiotemporal grids # So we expand the inv_freq to shape (3, ...) 
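The comment just above ("expand the inv_freq to shape (3, ...)") is the heart of the multimodal RoPE: a single frequency table is evaluated against three parallel position-id streams (temporal, height, width). A toy shape walk-through, with invented sizes and text-only streams (all three rows identical), is sketched below.

```python
# Toy shapes only; in the real model `dim` is the rotary head dim and the three
# position-id rows come from get_rope_index (temporal / height / width).
import torch

dim, bs, seq = 8, 2, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))           # (dim/2,)
position_ids = torch.arange(seq).view(1, 1, seq).expand(3, bs, seq).float()   # (3, bs, seq)

inv_freq_expanded = inv_freq[None, None, :, None].expand(3, bs, -1, 1)        # (3, bs, dim/2, 1)
position_ids_expanded = position_ids[:, :, None, :]                           # (3, bs, 1, seq)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)           # (3, bs, seq, dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                                       # (3, bs, seq, dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([3, 2, 5, 8]) -- one cos/sin table per position stream
```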
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) @@ -206,7 +206,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim Explanation: Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, height and width) of text embedding is always the same, so the text embedding rotary position embedding has no @@ -636,7 +636,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
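The `apply_multimodal_rotary_pos_emb` docstring edited above describes carving the rotary channels into temporal/height/width sections. A hedged toy of just that selection step; `mrope_section` and all sizes here are invented, in the real model they come from the config's `rope_scaling`.

```python
# cos carries one row per position stream, shape (3, bs, seq, head_dim); mrope_section
# lists how many of the dim/2 frequency channels each stream owns (toy values here).
import torch

bs, seq, head_dim = 1, 4, 16
mrope_section = [2, 3, 3]                      # sums to head_dim // 2
cos = torch.randn(3, bs, seq, head_dim)

sections = mrope_section * 2                   # emb = cat(freqs, freqs), so repeat the layout
chunks = cos.split(sections, dim=-1)           # six chunks along the channel dim
merged = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
print(merged.shape)                            # torch.Size([1, 4, 16]): stream-mixed cos table
```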
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() @@ -1066,9 +1066,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): def get_device(self) -> torch.device: return self.blocks[0].mlp.fc2.weight.device - def rot_pos_emb(self, grid_thw): + def rot_pos_emb(self, grid_spatiotemporal): pos_ids = [] - for t, h, w in grid_thw: + for t, h, w in grid_spatiotemporal: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = hpos_ids.reshape( h // self.spatial_merge_size, @@ -1090,16 +1090,16 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): wpos_ids = wpos_ids.flatten() pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() + max_grid_size = grid_spatiotemporal[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_spatiotemporal: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + cu_seqlens = torch.repeat_interleave(grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]).cumsum( dim=0, dtype=torch.int32 ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) @@ -1358,7 +1358,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): - The device to plcae the 4D attention mask on. + The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -1467,9 +1467,9 @@ QWEN2_VL_INPUTS_DOCSTRING = r""" The tensors corresponding to the input videos. Pixel values can be obtained using [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses [`Qwen2VLImageProcessor`] for processing videos. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. 
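Each row of `image_grid_spatiotemporal` is a `(t, h, w)` patch-grid size, and the vision tower's varlen attention needs cumulative per-image token boundaries, which is what the `cu_seqlens` construction above computes. A small numeric check with two invented grids:

```python
# Two hypothetical images: 4x6 and 8x8 patch grids, one temporal frame each.
import torch
import torch.nn.functional as F

grid_spatiotemporal = torch.tensor([[1, 4, 6], [1, 8, 8]])   # (num_images, 3) = (t, h, w)
tokens_per_frame = grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2]
cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_spatiotemporal[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 88], dtype=torch.int32) -- per-image token boundaries
```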
@@ -1531,8 +1531,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi def get_rope_index( self, input_ids: torch.LongTensor, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, + image_grid_spatiotemporal: Optional[torch.LongTensor] = None, + video_grid_spatiotemporal: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -1541,7 +1541,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi Explanation: Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. Examples: input_ids: [T T T T T], here T is for text. temporal position_ids: [0, 1, 2, 3, 4] @@ -1565,9 +1565,9 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: @@ -1584,7 +1584,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id mrope_position_deltas = [] - if image_grid_thw is not None or video_grid_thw is not None: + if image_grid_spatiotemporal is not None or video_grid_spatiotemporal is not None: total_input_ids = input_ids position_ids = torch.ones( 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device @@ -1613,18 +1613,18 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi ed_video = len(input_tokens) + 1 if ed_image < ed_video: t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], + image_grid_spatiotemporal[image_index][0], + image_grid_spatiotemporal[image_index][1], + image_grid_spatiotemporal[image_index][2], ) image_index += 1 remain_images -= 1 ed = ed_image else: t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], + video_grid_spatiotemporal[video_index][0], + video_grid_spatiotemporal[video_index][1], + video_grid_spatiotemporal[video_index][2], ) video_index += 1 remain_videos -= 1 @@ -1717,8 +1717,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return_dict: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, + image_grid_spatiotemporal: Optional[torch.LongTensor] = None, + video_grid_spatiotemporal: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, actions: Optional[torch.LongTensor] = None, states: Optional[torch.FloatTensor] = None, @@ -1774,7 +1774,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi actions = actions.to(dtype=self.computed_type, device="cuda") states = states.to(dtype=self.computed_type, device="cuda") position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask + input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask ) if pixel_values is not None: pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda") @@ -1790,7 +1790,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_embeds = self.visual(pixel_values, grid_spatiotemporal=image_grid_spatiotemporal) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: @@ -1808,7 +1808,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi if pixel_values_videos is not None: pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_embeds = self.visual(pixel_values_videos, grid_spatiotemporal=video_grid_spatiotemporal) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: @@ -1917,7 +1917,7 @@ 
class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi del inputs_embeds del labels del pixel_values - del image_grid_thw + del image_grid_spatiotemporal del actions del states return Qwen2VLCausalLMOutputWithPast( @@ -1969,8 +1969,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi use_cache=True, pixel_values=None, pixel_values_videos=None, - image_grid_thw=None, - video_grid_thw=None, + image_grid_spatiotemporal=None, + video_grid_spatiotemporal=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1988,7 +1988,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi if attention_mask is not None and position_ids is None: if cache_position is None or (cache_position is not None and cache_position[0] == 0): position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask + input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask ) else: batch_size, seq_length = input_ids.shape @@ -2040,8 +2040,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi "attention_mask": attention_mask, "pixel_values": pixel_values, "pixel_values_videos": pixel_values_videos, - "image_grid_thw": image_grid_thw, - "video_grid_thw": video_grid_thw, + "image_grid_spatiotemporal": image_grid_spatiotemporal, + "video_grid_spatiotemporal": video_grid_spatiotemporal, "rope_deltas": rope_deltas, } ) diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py index 08ede3e0..81988998 100644 --- a/lerobot/common/policies/dexvla/robot_data_processor.py +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -97,10 +97,10 @@ class Qwen2VLAProcess: input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances] labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances] - image_grid_thw = torch.stack([instances["image_grid_thw"] for instances in instances]) + image_grid_spatiotemporal = torch.stack([instances["image_grid_spatiotemporal"] for instances in instances]) pixel_values = torch.stack([instances["pixel_values"] for instances in instances]) pixel_values_videos = None - video_grid_thw = None + video_grid_spatiotemporal = None labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) labels = torch.flip(labels, dims=[1]) @@ -110,7 +110,7 @@ class Qwen2VLAProcess: input_ids = torch.flip(input_ids, dims=[1]) b = input_ids.shape[0] - image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2]) + image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]) pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2]) attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),) @@ -119,9 +119,9 @@ class Qwen2VLAProcess: input_ids=input_ids, attention_mask=attention_mask[0], labels=labels, - image_grid_thw=image_grid_thw, + image_grid_spatiotemporal=image_grid_spatiotemporal, pixel_values_videos=pixel_values_videos, - video_grid_thw=video_grid_thw, + video_grid_spatiotemporal=video_grid_spatiotemporal, pixel_values=pixel_values, ) return batch From 5dd5fc95b55313634b807aa0c9f6c57067724e05 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 05:18:46 +0000 Subject: [PATCH 13/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py | 6 +++--- lerobot/common/policies/dexvla/robot_data_processor.py | 8 ++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index 341887ff..7e4beb5b 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -1099,9 +1099,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal) - cu_seqlens = torch.repeat_interleave(grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]).cumsum( - dim=0, dtype=torch.int32 - ) + cu_seqlens = torch.repeat_interleave( + grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0] + ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) for blk in self.blocks: diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py index 81988998..4b75c439 100644 --- a/lerobot/common/policies/dexvla/robot_data_processor.py +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -97,7 +97,9 @@ class Qwen2VLAProcess: input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances] labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances] - image_grid_spatiotemporal = torch.stack([instances["image_grid_spatiotemporal"] for instances in instances]) + image_grid_spatiotemporal = torch.stack( + [instances["image_grid_spatiotemporal"] for instances in instances] + ) pixel_values = torch.stack([instances["pixel_values"] for instances in instances]) pixel_values_videos = None video_grid_spatiotemporal = None @@ -110,7 +112,9 @@ class Qwen2VLAProcess: input_ids = torch.flip(input_ids, dims=[1]) b = input_ids.shape[0] - image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]) + image_grid_spatiotemporal = image_grid_spatiotemporal.reshape( + b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2] + ) pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2]) attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),) From d927a907622254ca92c100515def03c31c669c75 Mon Sep 17 00:00:00 2001 From: wk Date: Tue, 11 Mar 2025 14:02:30 +0800 Subject: [PATCH 14/36] fix modeling --- .../configuration_unet_diffusion.py | 4 +- .../dexvla/policy_heads/modeling_scaledp.py | 42 +++++++++---------- .../policy_heads/modeling_unet_diffusion.py | 4 +- .../dexvla/qwe2_vla/modeling_qwen2_vla.py | 6 +-- 4 files changed, 27 insertions(+), 29 deletions(-) diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py index aaf66447..6ca6fcbe 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py @@ -19,7 +19,7 @@ class 
UnetDiffusionPolicyConfig(PretrainedConfig): action_dim=10, global_cond_dim=2048, diffusion_step_embed_dim=256, - down_dims=[256, 512, 1024], + down_dims=None, kernel_size=5, n_groups=8, state_dim=7, @@ -29,6 +29,8 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): num_train_timesteps=100, **kwargs, ): + if down_dims is None: + down_dims = [256, 512, 1024] self.input_dim = action_dim self.noise_samples = noise_samples self.prediction_horizon = prediction_horizon diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index 0614a3b6..5678625e 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -6,14 +6,9 @@ from typing import Tuple import numpy as np -try: - from typing import Literal -except ImportError: - pass - import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as Func import torch.utils.checkpoint from timm.models.vision_transformer import Mlp, use_fused_attn from torch.jit import Final @@ -51,13 +46,13 @@ class Attention(nn.Module): self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: - x = F.scaled_dot_product_attention( + x = Func.scaled_dot_product_attention( q, k, v, @@ -79,7 +74,7 @@ class Attention(nn.Module): attn_scores += attn_mask # Apply softmax to get attention weights (softmax is applied along the last dimension) - attn_weights = F.softmax(attn_scores, dim=-1) + attn_weights = Func.softmax(attn_scores, dim=-1) # Dropout on attention weights (if dropout is used) attn_weights = self.attn_drop(attn_weights) @@ -87,7 +82,7 @@ class Attention(nn.Module): # Apply attention weights to value tensor (V) x = torch.matmul(attn_weights, v) - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(b, n, c) x = self.proj(x) x = self.proj_drop(x) return x @@ -162,7 +157,8 @@ class ScaleDPBlock(nn.Module): self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) - approx_gelu = lambda: nn.GELU(approximate="tanh") + def approx_gelu(): + return nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) @@ -213,15 +209,15 @@ class ScaleDP(PreTrainedModel): # compute number of tokens for main trunk and conScaleDPion encoder if config.n_obs_steps is None: config.n_obs_steps = config.prediction_horizon - T = config.prediction_horizon - T_cond = 1 + t = config.prediction_horizon + t_cond = 1 if not config.time_as_cond: - T += 1 - T_cond -= 1 + t += 1 + t_cond -= 1 obs_as_cond = config.cond_dim > 0 if obs_as_cond: assert config.time_as_cond - T_cond += config.n_obs_steps + t_cond += config.n_obs_steps # self.combine = nn.Linear(cond_dim+state_dim, cond_dim) self.combine = nn.Sequential( @@ -254,8 +250,8 @@ class ScaleDP(PreTrainedModel): self.final_layer = FinalLayer(config.n_emb, 
output_dim=config.output_dim) # self.initialize_weights() # constants - self.T = T - self.T_cond = T_cond + self.t = t + self.t_cond = t_cond self.prediction_horizon = config.prediction_horizon self.time_as_cond = config.time_as_cond self.action_dim = config.output_dim @@ -328,7 +324,7 @@ class ScaleDP(PreTrainedModel): whitelist_weight_modules = (torch.nn.Linear, Attention) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): - for pn, p in m.named_parameters(): + for pn, _p in m.named_parameters(): fpn = "{}.{}".format(mn, pn) if mn else pn # full param name if pn.endswith("bias"): @@ -345,7 +341,7 @@ class ScaleDP(PreTrainedModel): no_decay.add(fpn) # validate that we considered every parameter - param_dict = {pn: p for pn, p in self.named_parameters()} + param_dict = dict(self.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( @@ -360,11 +356,11 @@ class ScaleDP(PreTrainedModel): # create the pytorch optimizer object optim_groups = [ { - "params": [param_dict[pn] for pn in sorted(list(decay))], + "params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay, }, { - "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0, }, ] diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py index eba83e36..210f6ba5 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -128,7 +128,7 @@ class ConditionalUnet1D(PreTrainedModel): in addition to diffusion step embedding. This is usually obs_horizon * obs_dim diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k down_dims: Channel size for each UNet level. - The length of this array determines numebr of levels. + The length of this array determines number of levels. 
kernel_size: Conv kernel size n_groups: Number of groups for GroupNorm """ @@ -301,7 +301,7 @@ class ConditionalUnet1D(PreTrainedModel): Tp = self.num_queries action_dim = 14 - # initialize action from Guassian noise + # initialize action from Gaussian noise noisy_action = torch.randn((B, Tp, action_dim)).cuda() naction = noisy_action.to(dtype=hidden_states.dtype) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index 341887ff..aef41381 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -1937,12 +1937,12 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi inputs_index = inputs_index.int() xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:]) - indexs = torch.argmax((xor_array != 0).float(), dim=1) + indexes = torch.argmax((xor_array != 0).float(), dim=1) input_embeddings = [] reasoning_embeddings = [] identity = [] - for i in range(indexs.shape[0]): - end = indexs[i] + 1 + for i in range(indexes.shape[0]): + end = indexes[i] + 1 temp = input_ids[i] == 151643 # pad token id for qwen2_vl start = sum(temp.int()) input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :])) From bf954dc715b377207ec59471083e2de226d01f94 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 06:03:00 +0000 Subject: [PATCH 15/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/policies/dexvla/policy_heads/modeling_scaledp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index 5678625e..4f6f3237 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -5,7 +5,6 @@ import math from typing import Tuple import numpy as np - import torch import torch.nn as nn import torch.nn.functional as Func @@ -157,8 +156,10 @@ class ScaleDPBlock(nn.Module): self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) + def approx_gelu(): return nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) From e2a41716a431120447b6b0bdd9d31b65042371bc Mon Sep 17 00:00:00 2001 From: wk Date: Tue, 11 Mar 2025 14:31:12 +0800 Subject: [PATCH 16/36] cyf_fix --- .../policy_heads/configuration_scaledp.py | 4 +- .../dexvla/policy_heads/modeling_scaledp.py | 32 +++++++-------- .../policy_heads/modeling_unet_diffusion.py | 40 +++++++++---------- .../dexvla/qwe2_vla/modeling_qwen2_vla.py | 40 +++++++++---------- 4 files changed, 57 insertions(+), 59 deletions(-) diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py index 6a8f7ea9..385c8dc1 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py @@ -7,12 
+7,12 @@ from transformers.utils import logging logger = logging.get_logger(__name__) MODEL_STRUCTURE = { - "ScaleDP_H": { + "scaledp_h": { "depth": 32, "n_emb": 1280, "num_heads": 16, }, - "ScaleDP_L": { + "scaledp_l": { "depth": 24, "n_emb": 1024, "num_heads": 16, diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index 5678625e..ed853106 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -8,12 +8,13 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as Func +import torch.nn.functional as func import torch.utils.checkpoint from timm.models.vision_transformer import Mlp, use_fused_attn from torch.jit import Final from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel +from .configuration_scaledp import ScaleDPPolicyConfig _logger = logging.getLogger(__name__) @@ -52,7 +53,7 @@ class Attention(nn.Module): q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: - x = Func.scaled_dot_product_attention( + x = func.scaled_dot_product_attention( q, k, v, @@ -74,7 +75,7 @@ class Attention(nn.Module): attn_scores += attn_mask # Apply softmax to get attention weights (softmax is applied along the last dimension) - attn_weights = Func.softmax(attn_scores, dim=-1) + attn_weights = func.softmax(attn_scores, dim=-1) # Dropout on attention weights (if dropout is used) attn_weights = self.attn_drop(attn_weights) @@ -191,7 +192,6 @@ class FinalLayer(nn.Module): return x -from .configuration_scaledp import ScaleDPPolicyConfig class ScaleDP(PreTrainedModel): @@ -379,23 +379,23 @@ class ScaleDP(PreTrainedModel): def forward(self, actions, hidden_states, states, is_pad): """ Forward pass for the diffusion head. 
- :param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1 - :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [B,Tokens, D] 8 1200 1024 - :param states: robot states, shape [B, D] + :param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1 + :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [b,Tokens, D] 8 1200 1024 + :param states: robot states, shape [b, D] :return: loss """ if actions is not None: # training time - B = actions.size(0) + b = actions.size(0) actions = actions[:, : self.num_queries] is_pad = is_pad[:, : self.num_queries] num_noise_samples = self.noise_samples # sample noise to add to actions noise = torch.randn( [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype - ) # num_noise, B, Ta, D(1, 2, 16, 14) + ) # num_noise, b, Ta, D(1, 2, 16, 14) # sample a diffusion iteration for each data point timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device + 0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device ).long() timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) @@ -405,7 +405,7 @@ class ScaleDP(PreTrainedModel): noisy_actions = torch.cat( [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], dim=0, - ) # [num_noise_samples * B, Ta, action_dim] + ) # [num_noise_samples * b, Ta, action_dim] noisy_actions = noisy_actions.to(dtype=actions.dtype) assert hidden_states.ndim == 3 @@ -425,12 +425,12 @@ class ScaleDP(PreTrainedModel): return {"loss": loss} # return loss else: # inference time - B = 1 - Tp = self.num_queries + b = 1 + tp = self.num_queries action_dim = self.action_dim # initialize action from Gaussian noise - noisy_action = torch.randn((B, Tp, action_dim)).cuda() + noisy_action = torch.randn((b, tp, action_dim)).cuda() naction = noisy_action.to(dtype=hidden_states.dtype) # init scheduler @@ -540,11 +540,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): ################################################################################# -def ScaleDP_H(**kwargs): +def scaledp_h(**kwargs): return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs) -def ScaleDP_L(**kwargs): +def scaledp_l(**kwargs): return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py index 210f6ba5..9a6a5f98 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -251,23 +251,23 @@ class ConditionalUnet1D(PreTrainedModel): def forward(self, actions, hidden_states, states, is_pad): """ Forward pass for the diffusion head. 
- :param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1 - :param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [B,Tokens, D] 8 1200 1024 - :param states: robot states, shape [B, D] + :param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1 + :param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024 + :param states: robot states, shape [b, D] :return: loss """ if actions is not None: # training time - B = actions.size(0) + b = actions.size(0) actions = copy.deepcopy(actions[:, : self.num_queries]) is_pad = copy.deepcopy(is_pad[:, : self.num_queries]) num_noise_samples = self.noise_samples # sample noise to add to actions noise = torch.randn( [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype - ) # num_noise, B, Ta, D + ) # num_noise, b, Ta, D # sample a diffusion iteration for each data point timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device + 0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device ).long() timesteps, noise = timesteps.to(actions.device), noise.to(actions.device) @@ -277,7 +277,7 @@ class ConditionalUnet1D(PreTrainedModel): noisy_actions = torch.cat( [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))], dim=0, - ) # [num_noise_samples * B, Ta, action_dim] + ) # [num_noise_samples * b, Ta, action_dim] noisy_actions = noisy_actions.to(dtype=actions.dtype) assert hidden_states.ndim == 3 @@ -297,12 +297,12 @@ class ConditionalUnet1D(PreTrainedModel): return {"loss": loss} # return loss else: # inference time - B = 1 - Tp = self.num_queries + b = 1 + tp = self.num_queries action_dim = 14 # initialize action from Gaussian noise - noisy_action = torch.randn((B, Tp, action_dim)).cuda() + noisy_action = torch.randn((b, tp, action_dim)).cuda() naction = noisy_action.to(dtype=hidden_states.dtype) # init scheduler @@ -323,14 +323,14 @@ class ConditionalUnet1D(PreTrainedModel): self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None ): """ - x: (B,T,input_dim) - timestep: (B,) or int, diffusion step - global_cond: (B,global_cond_dim) - output: (B,T,input_dim) + x: (b,T,input_dim) + timestep: (b,) or int, diffusion step + global_cond: (b,global_cond_dim) + output: (b,T,input_dim) """ - # (B,T,C) + # (b,t,c) sample = sample.moveaxis(-1, -2) - # (B,C,T) + # (b,c,t) # global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1) global_cond = global_cond.squeeze(1) @@ -353,7 +353,7 @@ class ConditionalUnet1D(PreTrainedModel): x = sample h = [] - for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): x = resnet(x, global_feature) x = resnet2(x, global_feature) h.append(x) @@ -362,7 +362,7 @@ class ConditionalUnet1D(PreTrainedModel): for mid_module in self.mid_modules: x = mid_module(x, global_feature) - for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): x = torch.cat((x, h.pop()), dim=1) x = resnet(x, global_feature) x = resnet2(x, global_feature) @@ -370,9 +370,9 @@ class ConditionalUnet1D(PreTrainedModel): x = self.final_conv(x) - # (B,C,T) + # (b,c,t) x = x.moveaxis(-1, -2) - # (B,T,C) + # (b,t,c) return x diff --git 
a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index b96c1f3d..efbc3c3a 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as func import torch.utils.checkpoint from torch.nn import CrossEntropyLoss, LayerNorm from transformers import AutoConfig, AutoModel @@ -48,7 +48,7 @@ from transformers.utils import ( replace_return_docstrings, ) -from lerobot.common.policies.dexvla.fusion_modules import * +from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig @@ -407,7 +407,7 @@ class VisionSdpaAttention(nn.Module): q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = func.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) @@ -879,7 +879,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False + is_causal = bool(causal_mask is None and q_len > 1) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, @@ -1102,7 +1102,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): cu_seqlens = torch.repeat_interleave( grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0] ).cumsum(dim=0, dtype=torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + cu_seqlens = func.pad(cu_seqlens, (1, 0), value=0) for blk in self.blocks: hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) @@ -1164,12 +1164,11 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1281,15 +1280,15 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( + and AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, sliding_window=self.config.sliding_window, is_training=self.training, - ): - return None + ) + ): + return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min @@ -1377,14 +1376,13 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: + if config.sliding_window is not None and (not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length): # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask |= sliding_attend_mask + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: From ed30c7d0a54db3b679e418e22a867c80c69a9617 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 06:31:34 +0000 Subject: [PATCH 17/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/policies/dexvla/policy_heads/modeling_scaledp.py | 3 +-- .../common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index ba5766fa..b09f5d24 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -13,6 +13,7 @@ from timm.models.vision_transformer import Mlp, use_fused_attn from torch.jit import Final from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel + from .configuration_scaledp import ScaleDPPolicyConfig _logger = logging.getLogger(__name__) @@ -193,8 +194,6 @@ class FinalLayer(nn.Module): return x - - class ScaleDP(PreTrainedModel): """ Diffusion models with a Transformer backbone. 
diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index efbc3c3a..164cd4ab 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -48,7 +48,7 @@ from transformers.utils import ( replace_return_docstrings, ) -from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM +from lerobot.common.policies.dexvla.fusion_modules import ActionProjector, FiLM from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig @@ -1376,7 +1376,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None and (not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length): + if config.sliding_window is not None and ( + not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length + ): # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not sliding_attend_mask = torch.arange(target_length, device=device) <= ( From ca3027cec20b309adb7d717610597d5633b4fd55 Mon Sep 17 00:00:00 2001 From: wk Date: Tue, 11 Mar 2025 14:39:21 +0800 Subject: [PATCH 18/36] fix_processor --- .../dexvla/qwe2_vla/modeling_qwen2_vla.py | 3 +- .../policies/dexvla/robot_data_processor.py | 35 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index efbc3c3a..da47cd1b 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -49,6 +49,7 @@ from transformers.utils import ( ) from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM +from transformers import AutoModelForCausalLM from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig @@ -2047,6 +2048,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return model_inputs -from transformers import AutoModelForCausalLM + AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py index 4b75c439..db350165 100644 --- a/lerobot/common/policies/dexvla/robot_data_processor.py +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -30,15 +30,16 @@ class Qwen2VLAProcess: len_views = images.shape[0] messages = self.construct_chat_data(len_views, raw_lang) - data_dict = dict( - messages=messages, - ) + data_dict = { + "messages":messages + } + image_data = torch.chunk(images, len_views, 0) images_list = [] - for i, each in enumerate(image_data): + for _i, each in enumerate(image_data): img_pil = self.qwen2_image_preprocess(each) images_list.append(img_pil) @@ -58,10 +59,7 @@ class Qwen2VLAProcess: return model_inputs input_labels = torch.ones_like(model_inputs["input_ids"]) * -100 - if use_reasoning: - answer = reasoning + "Next action:" + "<|im_end|>" - else: - answer = "" + "<|im_end|>" + answer = reasoning + "Next action:" + 
"<|im_end|>" if use_reasoning else "" + "<|im_end|>" output_text = self.tokenizer(answer, padding=True, return_tensors="pt") output_labels = output_text["input_ids"] @@ -119,15 +117,16 @@ class Qwen2VLAProcess: attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),) - batch = dict( - input_ids=input_ids, - attention_mask=attention_mask[0], - labels=labels, - image_grid_spatiotemporal=image_grid_spatiotemporal, - pixel_values_videos=pixel_values_videos, - video_grid_spatiotemporal=video_grid_spatiotemporal, - pixel_values=pixel_values, - ) + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask[0], + "labels": labels, + "image_grid_spatiotemporal": image_grid_spatiotemporal, + "pixel_values_videos": pixel_values_videos, + "video_grid_spatiotemporal": video_grid_spatiotemporal, + "pixel_values": pixel_values, + } + return batch def construct_chat_data(self, len_image, raw_lang): @@ -138,7 +137,7 @@ class Qwen2VLAProcess: }, ] - for i in range(len_image): + for _i in range(len_image): messages[0]["content"].append( { "type": "image", From f321f7c380b31af69a35fe3185d81ac0c44487d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 06:43:21 +0000 Subject: [PATCH 19/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py | 7 ++----- lerobot/common/policies/dexvla/robot_data_processor.py | 5 +---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index 98215660..0fd81253 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -28,7 +28,7 @@ import torch.nn as nn import torch.nn.functional as func import torch.utils.checkpoint from torch.nn import CrossEntropyLoss, LayerNorm -from transformers import AutoConfig, AutoModel +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from transformers.activations import ACT2FN from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin @@ -48,8 +48,7 @@ from transformers.utils import ( replace_return_docstrings, ) -from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM -from transformers import AutoModelForCausalLM +from lerobot.common.policies.dexvla.fusion_modules import ActionProjector, FiLM from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig @@ -2050,6 +2049,4 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return model_inputs - - AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py index db350165..d57035db 100644 --- a/lerobot/common/policies/dexvla/robot_data_processor.py +++ b/lerobot/common/policies/dexvla/robot_data_processor.py @@ -30,10 +30,7 @@ class Qwen2VLAProcess: len_views = images.shape[0] messages = self.construct_chat_data(len_views, raw_lang) - data_dict = { - "messages":messages - } - + data_dict = {"messages": messages} image_data = torch.chunk(images, len_views, 0) From 6bd6a9d63d9db69001d23fa4996e39785b435135 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 
Mar 2025 15:15:33 +0800 Subject: [PATCH 20/36] Add support for two training --- .../policies/dexvla/configuration_dexvla.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index 6f3c0ef0..1519da14 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -23,6 +23,7 @@ from transformers import AutoConfig from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, + ConstantWithWarmupSchedulerConfig ) from transformers.utils import logging from lerobot.configs.policies import PreTrainedConfig @@ -45,9 +46,12 @@ class DexVLAConfig(PreTrainedConfig): n_obs_steps: int = 1 hidden_size: int = 1536 - qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct' + qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl - pretrained_path: str = None # pretrained dexvla + pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3 + pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1) + + training_stage: int = 2 # specific training stage, [2, 3] using_film: bool = True llm_loss_weight: float = 1.0 with_llm_head: bool = True @@ -59,7 +63,7 @@ class DexVLAConfig(PreTrainedConfig): optimizer_eps: float = 1e-8 optimizer_weight_decay: float = 1e-10 - scheduler_warmup_steps: int = 1_000 + scheduler_warmup_steps: int = 2_000 scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 @@ -110,6 +114,9 @@ class DexVLAConfig(PreTrainedConfig): else: raise ValueError(f'Policy head type {self.policy_head_type} not supported') + if self.training_stage not in [2,3]: + raise ValueError(f"Training stage must be 2 or 3. 
Got {self.training_stage}.") + self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path) def validate_features(self) -> None: @@ -134,12 +141,17 @@ class DexVLAConfig(PreTrainedConfig): ) def get_scheduler_preset(self): - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=self.optimizer_lr, - decay_lr=self.scheduler_decay_lr, - num_warmup_steps=self.scheduler_warmup_steps, - num_decay_steps=self.scheduler_decay_steps, - ) + if self.training_stage == 3: + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + else: + return ConstantWithWarmupSchedulerConfig( + num_warmup_steps=self.scheduler_warmup_steps, + ) @property def observation_delta_indices(self) -> None: From 8998ba3bb59f57747b57cf55dbe22ca8fb4a87d3 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 15:15:39 +0800 Subject: [PATCH 21/36] Add support for two training --- .../common/policies/dexvla/modeling_dexvla.py | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index e9330a79..8db2d0ce 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -11,9 +11,10 @@ from collections import deque from lerobot.common.policies.dexvla.policy_heads.modeling_unet_diffusion import ConditionalUnet1D from lerobot.common.policies.dexvla.policy_heads.modeling_scaledp import ScaleDP from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess -from transformers import AutoProcessor, AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM import torchvision.transforms as transforms - +import os +from safetensors.torch import load_file class DexVLAPolicy(PreTrainedPolicy): """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" @@ -47,7 +48,37 @@ class DexVLAPolicy(PreTrainedPolicy): for k in ['using_film', 'llm_loss_weight', 'with_llm_head', 'policy_head_config']: setattr(config.qwen2_vla_config, k, config.__dict__[k]) - self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16) + # if self.config.training_stage == 2: + # self.model = Qwen2VLForConditionalGenerationForVLA(config.qwen2_vla_config).to(torch.bfloat16) + model_base = self.config.qwen2_vl_path + self.model = AutoModelForCausalLM.from_pretrained( + model_base, + config=config.qwen2_vla_config, + trust_remote_code=True, + _fast_init=False, + # attn_implementation="flash_attention_2", + ).to(device='cuda', dtype=torch.bfloat16) + + if self.config.pretrained_scaledp_path is not None: + print(f'\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<') + pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location='cpu') + + pretrain_scaledp_weights = pretrain_scaledp_weights['nets']['nets'] + + keys_to_del_dit = [] + pretrain_scaledp_weights = {k[7:] if k.startswith('policy.') else k: v for k, v in pretrain_scaledp_weights.items()} + for k in pretrain_scaledp_weights.keys(): + if 'noise_pred' not in k: # del weights of vision backbones + keys_to_del_dit.append(k) + if 'cond_obs_emb' in k: + keys_to_del_dit.append(k) + for k in keys_to_del_dit: + del 
pretrain_scaledp_weights[k] + pretrain_scaledp_weights = {k[15:] if k.startswith('noise_pred_net.') else k: v for k, v in + pretrain_scaledp_weights.items()} + + self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False) + self.model.requires_grad_(False) self.model.policy_head.requires_grad_(True) self.qwen2_vl_processor = AutoProcessor.from_pretrained(config.qwen2_vl_path) From 105650522ab629a5ee27a886910f100be0763c6f Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 15:16:08 +0800 Subject: [PATCH 22/36] Add a constant_warmup lr scheduler for dexvla --- lerobot/common/optim/schedulers.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index 7e158394..486a0b99 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -110,6 +110,29 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): return LambdaLR(optimizer, lr_lambda, -1) +@LRSchedulerConfig.register_subclass("constant_with_warmup") +@dataclass +class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig): + """Used by DexVLA to train Stage2""" + num_warmup_steps: int + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + + def lr_lambda(current_step): + def linear_warmup_schedule(current_step): + if current_step <= 0: + return 1 / (self.num_warmup_steps + 1) + frac = 1 - current_step / self.num_warmup_steps + return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1 + + def constant_schedule(current_step): + return 1 + + if current_step < self.num_warmup_steps: + return linear_warmup_schedule(current_step) + + return constant_schedule(current_step) + + return LambdaLR(optimizer, lr_lambda, -1) def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: state_dict = scheduler.state_dict() From cc46d93beb77c432192382162b37ec638c1780a1 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 15:16:22 +0800 Subject: [PATCH 23/36] Add a README.md --- lerobot/common/policies/dexvla/README.md | 112 +++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 lerobot/common/policies/dexvla/README.md diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md new file mode 100644 index 00000000..2f0d7786 --- /dev/null +++ b/lerobot/common/policies/dexvla/README.md @@ -0,0 +1,112 @@ +

+# DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Policy Learning

+ +### This is the lerobot version of DexVLA. For more information, you can refer to [this](https://github.com/juruobenruo/DexVLA). + +## Data Input +DexVLA takes into RGB images, language instructions and states. For our setting, we use three camera views: a top camera, two wrist cameras. + +⭐A major difference between DexVLA with other VLAs is: DexVLA takes raw language in, and outputs sub-step reasoning based on current observations and robot states. +So you have to add sub-step reasoning in your data for training. + +Specifically, your data should include a key ``reasoning`` which is a list of sub-step reasoning corresponding to each observation. +For example, if the episode is 10 steps. The length of this list should be 10 as well. And it may looks like: +~~~python +reasoning = [ + "This is step 1.", + "This is step 1.", + "This is step 2.", + "This is step 2.", + ... + "This is step 4.", +] +~~~ + +Besides, your data should include another key ``action_is_pad`` which is a bool mask indicated whether this action chunk is padded. +For example, suppose action chunk is 5, and the length of episode is 10. So the action chunk for last 4 actions must be padded to make sure the length of action chunk is 5. +And the mask looks like: +~~~python +The 6th chunk: [false, false, false, false, true] +The 7th chunk: [false, false, false, true, true] +The 8th chunk: [false, false, true, true, true] +The 9th chunk: [false, true, true, true, true] +~~~ + +## 🤗Download Pretrained Weights +### Download official Qwen2_VL weights +We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework. +The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities +for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed +in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post training on VLM itself. You can download the official weights from this link: + +| Model | Link | +|---------------------|----------------------------------------------------------------| +| Qwen2-VL (~2B) | [huggingface](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) | + +**❗❗** After downloading the standard weights, you have to replace the official "config.json" +with our ["config.json"](https://github.com/juruobenruo/DexVLA/blob/main/docs/config.json) designed for VLA. +### Download our pretrained ScaleDP-H weights(Stage 1) +We released our pretrained weights of ScaleDP-H which is trained after Stage1. Now you can download the weights and directly finetuning your data on Stage 2. + +| Model | Link | +|-------------------|----------------------------------------------------------------| +| ScaleDP-H (~1B) | [huggingface](https://huggingface.co/lesjie/scale_dp_h) | +| ScaleDP-L (~400M) | [huggingface](https://huggingface.co/lesjie/scale_dp_l) | + +## 🦾Train +We have already provided pretrained weights of ScaleDP which is stage 1. Belows are mainly about training process of Stage2 and Stage3. 
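Before launching the stages below, it can help to sanity-check the two dataset keys described in Data Input (``reasoning`` aligned to every observation, and ``action_is_pad`` marking padded chunk entries). The following is a minimal sketch of one way to build the padding mask for a whole episode; it assumes 0-indexed chunk starts where the chunk at step ``t`` covers steps ``t`` to ``t + chunk_size - 1``, and ``build_action_is_pad`` is an illustrative helper rather than part of DexVLA or LeRobot.

~~~python
import torch


def build_action_is_pad(episode_len: int, chunk_size: int) -> torch.Tensor:
    # Step indices covered by the chunk that starts at each timestep t.
    steps = torch.arange(episode_len).unsqueeze(1) + torch.arange(chunk_size).unsqueeze(0)
    # Anything past the last real action is padding.
    return steps >= episode_len  # bool tensor of shape [episode_len, chunk_size]


mask = build_action_is_pad(episode_len=10, chunk_size=5)
print(mask[-4:])  # the last four chunks contain 1, 2, 3 and 4 padded entries, matching the example above
~~~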
+ +### Training Stage 2 +~~~shell +python lerobot/scripts/train.py \ +--policy.type dexvla \ +--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ +--policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.ckpt \ +--policy.policy_head_size 'ScaleDP_H' \ +--policy.training_stage 2 \ +--dataset.repo_i folding_blue_tshirt \ +--dataset.local_files_only true \ +--batch_size 2 \ +--policy.using_film true \ +--output_dir /path/to/output \ +--steps 10000 \ +--save_freq 1000 \ +--optimizer_lr 2e-5 \ +--policy.device=cuda +~~~ + +### Training Stage 3 +Stage3 can be viewed as continual training on specific dexterous tasks like laundry folding which is same as PI0. So stage3 is trained based on stage2. +~~~shell +python lerobot/scripts/train.py \ +--policy.type dexvla \ +--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ +--.pretrained_path /path/to/pretrained/stage2/weights \ +--policy.policy_head_size 'ScaleDP_H' \ +--policy.training_stage 3 \ +--dataset.repo_i folding_blue_tshirt \ +--dataset.local_files_only true \ +--batch_size 2 \ +--policy.using_film true \ +--output_dir /path/to/output \ +--steps 10000 \ +--save_freq 1000 \ +--optimizer_lr 2e-5 \ +--policy.device=cuda +~~~ + +## Evaluation +~~~shell +python lerobot/scripts/eval.py \ +--policy.type dexvla \ +--policy.pretrained_path /path/to/pretrained/stage2/or/stage3/weights \ +--env.type aloha \ +--env.episode_length 5 \ +--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ +--env.task AlohaInsertion-v0 \ +--eval.n_episodes 1 \ +--eval.batch_size 1 \ +--device cuda +~~~ + + From 61e40435aeb7ca1a213bf487a6875d4e7375ee12 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 15:31:39 +0800 Subject: [PATCH 24/36] update README.md --- lerobot/common/policies/dexvla/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md index 2f0d7786..5f906f1e 100644 --- a/lerobot/common/policies/dexvla/README.md +++ b/lerobot/common/policies/dexvla/README.md @@ -4,9 +4,9 @@ DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Polic ### This is the lerobot version of DexVLA. For more information, you can refer to [this](https://github.com/juruobenruo/DexVLA). ## Data Input -DexVLA takes into RGB images, language instructions and states. For our setting, we use three camera views: a top camera, two wrist cameras. +DexVLA takes RGB images, language instructions and states. For our setting, we use three camera views, namely a top camera and two wrist cameras. -⭐A major difference between DexVLA with other VLAs is: DexVLA takes raw language in, and outputs sub-step reasoning based on current observations and robot states. +⭐A major difference between DexVLA and other VLAs is: DexVLA takes in raw language, and outputs sub-step reasoning based on current observations. So you have to add sub-step reasoning in your data for training. Specifically, your data should include a key ``reasoning`` which is a list of sub-step reasoning corresponding to each observation. @@ -22,8 +22,8 @@ reasoning = [ ] ~~~ -Besides, your data should include another key ``action_is_pad`` which is a bool mask indicated whether this action chunk is padded. -For example, suppose action chunk is 5, and the length of episode is 10. So the action chunk for last 4 actions must be padded to make sure the length of action chunk is 5. 
+Besides, your data should include another key ``action_is_pad`` which is a bool mask indicating whether this action chunk is padded. +Suppose the size of the action chunk is 5, and the length of the episode is 10. So the action chunk for the last 4 actions must be padded to make sure the length of action chunk is 5. And the mask looks like: ~~~python The 6th chunk: [false, false, false, false, true] From 4407d19704c2fdccca41919643517cdc4bfe3b14 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Mar 2025 07:32:12 +0000 Subject: [PATCH 25/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/optim/schedulers.py | 7 ++-- lerobot/common/policies/dexvla/README.md | 12 +++---- .../policies/dexvla/configuration_dexvla.py | 14 ++++---- .../common/policies/dexvla/modeling_dexvla.py | 32 +++++++++---------- 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index 486a0b99..e2ebb9e3 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -110,13 +110,15 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): return LambdaLR(optimizer, lr_lambda, -1) + @LRSchedulerConfig.register_subclass("constant_with_warmup") @dataclass class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig): """Used by DexVLA to train Stage2""" - num_warmup_steps: int - def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + num_warmup_steps: int + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: def lr_lambda(current_step): def linear_warmup_schedule(current_step): if current_step <= 0: @@ -134,6 +136,7 @@ class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig): return LambdaLR(optimizer, lr_lambda, -1) + def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: state_dict = scheduler.state_dict() write_json(state_dict, save_dir / SCHEDULER_STATE) diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md index 5f906f1e..3f2bc310 100644 --- a/lerobot/common/policies/dexvla/README.md +++ b/lerobot/common/policies/dexvla/README.md @@ -26,7 +26,7 @@ Besides, your data should include another key ``action_is_pad`` which is a bool Suppose the size of the action chunk is 5, and the length of the episode is 10. So the action chunk for the last 4 actions must be padded to make sure the length of action chunk is 5. And the mask looks like: ~~~python -The 6th chunk: [false, false, false, false, true] +The 6th chunk: [false, false, false, false, true] The 7th chunk: [false, false, false, true, true] The 8th chunk: [false, false, true, true, true] The 9th chunk: [false, true, true, true, true] @@ -34,9 +34,9 @@ The 9th chunk: [false, true, true, true, true] ## 🤗Download Pretrained Weights ### Download official Qwen2_VL weights -We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework. -The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities -for vision-language tasks. We use off-the-shelf Qwen2-VL model proposed +We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework. +The Qwen2-VL 2B serves as the core of our architecture, providing robust capabilities +for vision-language tasks. 
We use off-the-shelf Qwen2-VL model proposed in [Qwen2-VL](https://arxiv.org/pdf/2409.12191) without any post training on VLM itself. You can download the official weights from this link: | Model | Link | @@ -106,7 +106,5 @@ python lerobot/scripts/eval.py \ --env.task AlohaInsertion-v0 \ --eval.n_episodes 1 \ --eval.batch_size 1 \ ---device cuda +--device cuda ~~~ - - diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index 8a54c0d2..6ca58228 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -21,8 +21,8 @@ from transformers.utils import logging from lerobot.common.optim.optimizers import AdamWConfig from lerobot.common.optim.schedulers import ( + ConstantWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig, - ConstantWithWarmupSchedulerConfig ) from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode @@ -43,12 +43,14 @@ class DexVLAConfig(PreTrainedConfig): n_obs_steps: int = 1 hidden_size: int = 1536 - qwen2_vl_path: str = None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl + qwen2_vl_path: str = ( + None # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl + ) - pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3 - pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1) + pretrained_path: str = None # for loading pretrained weights of whole dexvla, usually for training stage3 + pretrained_scaledp_path: str = None # for loading pretrained weights of ScaleDP(Stage1) - training_stage: int = 2 # specific training stage, [2, 3] + training_stage: int = 2 # specific training stage, [2, 3] using_film: bool = True llm_loss_weight: float = 1.0 with_llm_head: bool = True @@ -115,7 +117,7 @@ class DexVLAConfig(PreTrainedConfig): else: raise ValueError(f"Policy head type {self.policy_head_type} not supported") - if self.training_stage not in [2,3]: + if self.training_stage not in [2, 3]: raise ValueError(f"Training stage must be 2 or 3. 
Got {self.training_stage}.") self.qwen2_vla_config = AutoConfig.from_pretrained(self.qwen2_vl_path) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index c9513d07..ec58f783 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -3,20 +3,14 @@ from collections import deque import torch import torchvision.transforms as transforms from torch import Tensor -from transformers import AutoProcessor, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig -from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy -from collections import deque -from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM -import torchvision.transforms as transforms -import os -from safetensors.torch import load_file class DexVLAPolicy(PreTrainedPolicy): """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" @@ -59,25 +53,31 @@ class DexVLAPolicy(PreTrainedPolicy): trust_remote_code=True, _fast_init=False, # attn_implementation="flash_attention_2", - ).to(device='cuda', dtype=torch.bfloat16) + ).to(device="cuda", dtype=torch.bfloat16) if self.config.pretrained_scaledp_path is not None: - print(f'\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<') - pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location='cpu') + print( + "\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + ) + pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location="cpu") - pretrain_scaledp_weights = pretrain_scaledp_weights['nets']['nets'] + pretrain_scaledp_weights = pretrain_scaledp_weights["nets"]["nets"] keys_to_del_dit = [] - pretrain_scaledp_weights = {k[7:] if k.startswith('policy.') else k: v for k, v in pretrain_scaledp_weights.items()} + pretrain_scaledp_weights = { + k[7:] if k.startswith("policy.") else k: v for k, v in pretrain_scaledp_weights.items() + } for k in pretrain_scaledp_weights.keys(): - if 'noise_pred' not in k: # del weights of vision backbones + if "noise_pred" not in k: # del weights of vision backbones keys_to_del_dit.append(k) - if 'cond_obs_emb' in k: + if "cond_obs_emb" in k: keys_to_del_dit.append(k) for k in keys_to_del_dit: del pretrain_scaledp_weights[k] - pretrain_scaledp_weights = {k[15:] if k.startswith('noise_pred_net.') else k: v for k, v in - pretrain_scaledp_weights.items()} + pretrain_scaledp_weights = { + k[15:] if k.startswith("noise_pred_net.") else k: v + for k, v in pretrain_scaledp_weights.items() + } self.model.policy_head.load_state_dict(pretrain_scaledp_weights, strict=False) From b4853011f8687c72fd2a16b37e0a2bd4e48b5552 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 16:47:33 +0800 Subject: [PATCH 26/36] add __init__.py --- .../common/policies/dexvla/policy_heads/__init__.py | 13 +++++++++++++ .../dexvla/policy_heads/configuration_scaledp.py | 3 +-- .../policy_heads/configuration_unet_diffusion.py 
| 3 +-- .../dexvla/policy_heads/modeling_scaledp.py | 2 -- .../dexvla/policy_heads/modeling_unet_diffusion.py | 2 -- lerobot/common/policies/dexvla/qwe2_vla/__init__.py | 11 +++++++++++ .../dexvla/qwe2_vla/configuration_qwen2_vla.py | 2 -- .../policies/dexvla/qwe2_vla/modeling_qwen2_vla.py | 3 +-- 8 files changed, 27 insertions(+), 12 deletions(-) create mode 100644 lerobot/common/policies/dexvla/policy_heads/__init__.py create mode 100644 lerobot/common/policies/dexvla/qwe2_vla/__init__.py diff --git a/lerobot/common/policies/dexvla/policy_heads/__init__.py b/lerobot/common/policies/dexvla/policy_heads/__init__.py new file mode 100644 index 00000000..991e8e0b --- /dev/null +++ b/lerobot/common/policies/dexvla/policy_heads/__init__.py @@ -0,0 +1,13 @@ +from .configuration_scaledp import ScaleDPPolicyConfig +from .configuration_unet_diffusion import UnetDiffusionPolicyConfig +from .modeling_scaledp import ScaleDP +from .modeling_unet_diffusion import ConditionalUnet1D +from transformers import AutoConfig, AutoModel + + +def register_policy_heads(): + AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig) + AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig) + AutoModel.register(ScaleDPPolicyConfig, ScaleDP) + AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D) + diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py index 385c8dc1..8ccc6196 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py @@ -1,7 +1,7 @@ import os from typing import Union -from transformers import AutoConfig, PretrainedConfig +from transformers import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) @@ -106,4 +106,3 @@ class ScaleDPPolicyConfig(PretrainedConfig): return cls.from_dict(config_dict, **kwargs) -AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py index 6ca6fcbe..3c9af4d9 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py @@ -1,7 +1,7 @@ import os from typing import Union -from transformers import AutoConfig, PretrainedConfig +from transformers import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) @@ -70,4 +70,3 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): return cls.from_dict(config_dict, **kwargs) -AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index b09f5d24..5b8217fb 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -11,7 +11,6 @@ import torch.nn.functional as func import torch.utils.checkpoint from timm.models.vision_transformer import Mlp, use_fused_attn from torch.jit import Final -from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel from .configuration_scaledp import ScaleDPPolicyConfig @@ -548,4 +547,3 @@ def scaledp_l(**kwargs): return ScaleDP(depth=24, n_emb=1024, 
num_heads=16, **kwargs) -AutoModel.register(ScaleDPPolicyConfig, ScaleDP) diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py index 9a6a5f98..dc227ccb 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -11,7 +11,6 @@ import torch.nn as nn # requires diffusers==0.11.1 from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from transformers import AutoModel from transformers.modeling_utils import PreTrainedModel from .configuration_unet_diffusion import UnetDiffusionPolicyConfig @@ -376,4 +375,3 @@ class ConditionalUnet1D(PreTrainedModel): return x -AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/__init__.py b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py new file mode 100644 index 00000000..23c7b636 --- /dev/null +++ b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py @@ -0,0 +1,11 @@ +from .configuration_qwen2_vla import Qwen2VLAConfig +from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA + +from transformers import AutoConfig, AutoModelForCausalLM + + +def register_qwen2_vla(): + AutoConfig.register("qwen2_vla", Qwen2VLAConfig) + AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) + + diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py index 80717bc2..e3ea55b8 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py @@ -16,7 +16,6 @@ import os from typing import Union -from transformers import AutoConfig from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging @@ -254,4 +253,3 @@ class Qwen2VLAConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -AutoConfig.register("qwen2_vla", Qwen2VLAConfig) diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index 0fd81253..fa06a7b3 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -28,7 +28,7 @@ import torch.nn as nn import torch.nn.functional as func import torch.utils.checkpoint from torch.nn import CrossEntropyLoss, LayerNorm -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers import AutoConfig, AutoModel from transformers.activations import ACT2FN from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin @@ -2049,4 +2049,3 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi return model_inputs -AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) From fcb20473104acdd619164f25a686a7cfb80588f9 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 16:47:54 +0800 Subject: [PATCH 27/36] add register policy_head and qwen2_vla --- lerobot/common/policies/dexvla/configuration_dexvla.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git 
a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index 8a54c0d2..5c1f6743 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -28,14 +28,18 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode logger = logging.get_logger(__name__) +from .policy_heads import register_policy_heads +from .qwe2_vla import register_qwen2_vla +register_policy_heads() +register_qwen2_vla() @PreTrainedConfig.register_subclass("dexvla") @dataclass class DexVLAConfig(PreTrainedConfig): # For loading policy head policy_head_type: str = "scale_dp_policy" - policy_head_size: str = "ScaleDP_L" + policy_head_size: str = "scaledp_l" action_dim: int = 14 state_dim: int = 14 chunk_size: int = 50 From 3dd20bad4be5c4b82b2d2347af88e1087116e548 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 16:48:09 +0800 Subject: [PATCH 28/36] remove unused code --- lerobot/common/policies/dexvla/modeling_dexvla.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index c9513d07..48fdd2ad 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -1,12 +1,9 @@ from collections import deque import torch -import torchvision.transforms as transforms from torch import Tensor -from transformers import AutoProcessor, AutoTokenizer from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig -from lerobot.common.policies.dexvla.qwe2_vla.modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -15,8 +12,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from collections import deque from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM import torchvision.transforms as transforms -import os -from safetensors.torch import load_file + class DexVLAPolicy(PreTrainedPolicy): """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" @@ -69,7 +65,7 @@ class DexVLAPolicy(PreTrainedPolicy): keys_to_del_dit = [] pretrain_scaledp_weights = {k[7:] if k.startswith('policy.') else k: v for k, v in pretrain_scaledp_weights.items()} - for k in pretrain_scaledp_weights.keys(): + for k in pretrain_scaledp_weights: if 'noise_pred' not in k: # del weights of vision backbones keys_to_del_dit.append(k) if 'cond_obs_emb' in k: From 435463e3c9e212161972b56b8722153983a6859d Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 16:49:00 +0800 Subject: [PATCH 29/36] update README.md --- lerobot/common/policies/dexvla/README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md index 5f906f1e..51f6917f 100644 --- a/lerobot/common/policies/dexvla/README.md +++ b/lerobot/common/policies/dexvla/README.md @@ -62,10 +62,9 @@ python lerobot/scripts/train.py \ --policy.type dexvla \ --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ --policy.pretrain_scaledp_path 
/path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.ckpt \ ---policy.policy_head_size 'ScaleDP_H' \ +--policy.policy_head_size 'scaledp_h' \ --policy.training_stage 2 \ --dataset.repo_i folding_blue_tshirt \ ---dataset.local_files_only true \ --batch_size 2 \ --policy.using_film true \ --output_dir /path/to/output \ @@ -82,10 +81,9 @@ python lerobot/scripts/train.py \ --policy.type dexvla \ --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ --.pretrained_path /path/to/pretrained/stage2/weights \ ---policy.policy_head_size 'ScaleDP_H' \ +--policy.policy_head_size 'scaledp_h' \ --policy.training_stage 3 \ --dataset.repo_i folding_blue_tshirt \ ---dataset.local_files_only true \ --batch_size 2 \ --policy.using_film true \ --output_dir /path/to/output \ From 0bb966585d2611b51219b81579b5afc357742caf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Mar 2025 08:52:36 +0000 Subject: [PATCH 30/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/policies/dexvla/configuration_dexvla.py | 1 + lerobot/common/policies/dexvla/modeling_dexvla.py | 6 ++---- lerobot/common/policies/dexvla/policy_heads/__init__.py | 4 ++-- .../policies/dexvla/policy_heads/configuration_scaledp.py | 2 -- .../dexvla/policy_heads/configuration_unet_diffusion.py | 2 -- .../common/policies/dexvla/policy_heads/modeling_scaledp.py | 2 -- .../policies/dexvla/policy_heads/modeling_unet_diffusion.py | 2 -- lerobot/common/policies/dexvla/qwe2_vla/__init__.py | 6 ++---- .../policies/dexvla/qwe2_vla/configuration_qwen2_vla.py | 2 -- .../common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py | 2 -- 10 files changed, 7 insertions(+), 22 deletions(-) diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index f666b041..6462b3cc 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -34,6 +34,7 @@ from .qwe2_vla import register_qwen2_vla register_policy_heads() register_qwen2_vla() + @PreTrainedConfig.register_subclass("dexvla") @dataclass class DexVLAConfig(PreTrainedConfig): diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index 7c321c1b..8da0ed06 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -1,7 +1,9 @@ from collections import deque import torch +import torchvision.transforms as transforms from torch import Tensor +from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig from lerobot.common.policies.dexvla.robot_data_processor import Qwen2VLAProcess @@ -9,10 +11,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy -from collections import deque -from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM -import torchvision.transforms as transforms - class DexVLAPolicy(PreTrainedPolicy): """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" diff --git a/lerobot/common/policies/dexvla/policy_heads/__init__.py b/lerobot/common/policies/dexvla/policy_heads/__init__.py index 991e8e0b..bf25637c 100644 --- 
a/lerobot/common/policies/dexvla/policy_heads/__init__.py +++ b/lerobot/common/policies/dexvla/policy_heads/__init__.py @@ -1,8 +1,9 @@ +from transformers import AutoConfig, AutoModel + from .configuration_scaledp import ScaleDPPolicyConfig from .configuration_unet_diffusion import UnetDiffusionPolicyConfig from .modeling_scaledp import ScaleDP from .modeling_unet_diffusion import ConditionalUnet1D -from transformers import AutoConfig, AutoModel def register_policy_heads(): @@ -10,4 +11,3 @@ def register_policy_heads(): AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig) AutoModel.register(ScaleDPPolicyConfig, ScaleDP) AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D) - diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py index 8ccc6196..d84caf48 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py @@ -104,5 +104,3 @@ class ScaleDPPolicyConfig(PretrainedConfig): ) return cls.from_dict(config_dict, **kwargs) - - diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py index 3c9af4d9..3f40cbef 100644 --- a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py @@ -68,5 +68,3 @@ class UnetDiffusionPolicyConfig(PretrainedConfig): ) return cls.from_dict(config_dict, **kwargs) - - diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py index 5b8217fb..8a96eefe 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py @@ -545,5 +545,3 @@ def scaledp_h(**kwargs): def scaledp_l(**kwargs): return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs) - - diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py index dc227ccb..0631d9b3 100644 --- a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py +++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py @@ -373,5 +373,3 @@ class ConditionalUnet1D(PreTrainedModel): x = x.moveaxis(-1, -2) # (b,t,c) return x - - diff --git a/lerobot/common/policies/dexvla/qwe2_vla/__init__.py b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py index 23c7b636..38a5cbca 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/__init__.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py @@ -1,11 +1,9 @@ +from transformers import AutoConfig, AutoModelForCausalLM + from .configuration_qwen2_vla import Qwen2VLAConfig from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA -from transformers import AutoConfig, AutoModelForCausalLM - def register_qwen2_vla(): AutoConfig.register("qwen2_vla", Qwen2VLAConfig) AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA) - - diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py index e3ea55b8..628bca77 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py +++ 
b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py @@ -251,5 +251,3 @@ class Qwen2VLAConfig(PretrainedConfig): rope_config_validation(self, ignore_keys={"mrope_section"}) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - - diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py index fa06a7b3..2fccd565 100644 --- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py +++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py @@ -2047,5 +2047,3 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi ) model_inputs.update(kwargs) return model_inputs - - From 8d03cc8ad2f8188e9192d69db7fe0503b2548c62 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Tue, 18 Mar 2025 18:03:05 +0800 Subject: [PATCH 31/36] replace torch.load with safe_open --- lerobot/common/policies/dexvla/configuration_dexvla.py | 5 ++--- lerobot/common/policies/dexvla/modeling_dexvla.py | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index f666b041..176d9a51 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Qwen2VL model configuration""" +from .policy_heads import register_policy_heads +from .qwe2_vla import register_qwen2_vla from dataclasses import dataclass, field from typing import Tuple @@ -28,9 +30,6 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode logger = logging.get_logger(__name__) -from .policy_heads import register_policy_heads -from .qwe2_vla import register_qwen2_vla - register_policy_heads() register_qwen2_vla() diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py index 7c321c1b..43399bab 100644 --- a/lerobot/common/policies/dexvla/modeling_dexvla.py +++ b/lerobot/common/policies/dexvla/modeling_dexvla.py @@ -12,7 +12,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from collections import deque from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM import torchvision.transforms as transforms - +from safetensors.torch import load_file class DexVLAPolicy(PreTrainedPolicy): """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot.""" @@ -61,9 +61,7 @@ class DexVLAPolicy(PreTrainedPolicy): print( "\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" ) - pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location="cpu") - - pretrain_scaledp_weights = pretrain_scaledp_weights["nets"]["nets"] + pretrain_scaledp_weights = load_file(self.config.pretrained_scaledp_path) keys_to_del_dit = [] pretrain_scaledp_weights = { From e3bab739a3ae01d1ba7514af9c5eea2ca0c6aa74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Mar 2025 10:22:27 +0000 Subject: [PATCH 32/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/policies/dexvla/configuration_dexvla.py | 5 +++-- 
 lerobot/common/policies/dexvla/modeling_dexvla.py      | 5 +----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py
index 191cdd0b..76304057 100644
--- a/lerobot/common/policies/dexvla/configuration_dexvla.py
+++ b/lerobot/common/policies/dexvla/configuration_dexvla.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Qwen2VL model configuration"""
-from .policy_heads import register_policy_heads
-from .qwe2_vla import register_qwen2_vla
 
 from dataclasses import dataclass, field
 from typing import Tuple
@@ -29,6 +27,9 @@ from lerobot.common.optim.schedulers import (
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
 
+from .policy_heads import register_policy_heads
+from .qwe2_vla import register_qwen2_vla
+
 logger = logging.get_logger(__name__)
 register_policy_heads()
 register_qwen2_vla()
diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py
index bbd30907..b90535fa 100644
--- a/lerobot/common/policies/dexvla/modeling_dexvla.py
+++ b/lerobot/common/policies/dexvla/modeling_dexvla.py
@@ -2,6 +2,7 @@ from collections import deque
 
 import torch
 import torchvision.transforms as transforms
+from safetensors.torch import load_file
 from torch import Tensor
 from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
 
@@ -11,10 +12,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 
-from collections import deque
-from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
-import torchvision.transforms as transforms
-from safetensors.torch import load_file
 
 
 class DexVLAPolicy(PreTrainedPolicy):
     """Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""

From 172fd09fca21aec7504506e56c4cf3ced7471adc Mon Sep 17 00:00:00 2001
From: lesjie-wen <870351470@qq.com>
Date: Tue, 18 Mar 2025 18:30:49 +0800
Subject: [PATCH 33/36] add weights format transform

---
 lerobot/common/policies/dexvla/README.md | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md
index 4d832c01..75b4171e 100644
--- a/lerobot/common/policies/dexvla/README.md
+++ b/lerobot/common/policies/dexvla/README.md
@@ -53,6 +53,21 @@ We released our pretrained weights of ScaleDP-H which is trained after Stage1. N
 | ScaleDP-H (~1B)   | [huggingface](https://huggingface.co/lesjie/scale_dp_h)  |
 | ScaleDP-L (~400M) | [huggingface](https://huggingface.co/lesjie/scale_dp_l)  |
 
+**❗❗** After downloading the weights, you have to convert them into the ``safetensors`` format; you can simply run this code:
+~~~python
+import torch
+from safetensors.torch import save_file
+path = "/path/to/open_scale_dp_l_backbone.ckpt"
+checkpoint = torch.load(path, map_location=torch.device('cpu'))['nets']['nets']
+
+# Save the weights in safetensors format
+safetensors_path = "/path/to/open_scale_dp_l_backbone.safetensors"
+save_file(checkpoint, safetensors_path)
+print(f"Converted {path} to {safetensors_path}")
+pass
+
+~~~
+
 ## 🦾Train
 We have already provided pretrained weights of ScaleDP which is stage 1. Belows are mainly about training process of Stage2 and Stage3.
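As a quick sanity check before training, the converted file can be loaded back with ``safetensors``; this is a minimal sketch that assumes the placeholder path from the conversion snippet above, and ``--policy.pretrain_scaledp_path`` should then point at this ``.safetensors`` file:
~~~python
from safetensors.torch import load_file

# Placeholder path from the conversion snippet above (an assumption, adjust to your setup)
safetensors_path = "/path/to/open_scale_dp_l_backbone.safetensors"

# load_file returns a flat dict of tensors; inspect a few entries to confirm
# the ScaleDP weights were written correctly.
state_dict = load_file(safetensors_path, device="cpu")
print(f"{len(state_dict)} tensors")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)
~~~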
@@ -61,7 +76,7 @@ We have already provided pretrained weights of ScaleDP which is stage 1. Belows python lerobot/scripts/train.py \ --policy.type dexvla \ --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \ ---policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.ckpt \ +--policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \ --policy.policy_head_size 'scaledp_h' \ --policy.training_stage 2 \ --dataset.repo_i folding_blue_tshirt \ From 3c5bb6b0d6aac7ae2141e1bec2bd308e6640a7d1 Mon Sep 17 00:00:00 2001 From: lesjie-wen <870351470@qq.com> Date: Wed, 19 Mar 2025 13:58:28 +0800 Subject: [PATCH 34/36] update inference, training, dataset --- lerobot/common/policies/dexvla/README.md | 25 ++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md index 75b4171e..9d0b9805 100644 --- a/lerobot/common/policies/dexvla/README.md +++ b/lerobot/common/policies/dexvla/README.md @@ -1,9 +1,11 @@

DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Policy Learning

-### This is the lerobot version of DexVLA. For more information, you can refer to [this](https://github.com/juruobenruo/DexVLA).
+This policy is community contributed. For more information about DexVLA, you can also refer to [the original repo](https://github.com/juruobenruo/DexVLA).
+This is the [project website](https://dex-vla.github.io/). 
-## Data Input
+## Dataset
+### Data format
 DexVLA takes RGB images, language instructions and states. For our setting, we use three camera views, namely a top camera and two wrist cameras.
 
 ⭐A major difference between DexVLA and other VLAs is: DexVLA takes in raw language, and outputs sub-step reasoning based on current observations.
@@ -32,6 +34,10 @@ The 8th chunk: [false, false, true, true, true]
 The 9th chunk: [false, true, true, true, true]
 ~~~
 
+### Training Data for DexVLA
+The pretraining dataset comprises approximately 100 hours of data collected by ourselves. It mainly covers the following embodiments: mobile Agilex Aloha, single Franka Emika, and single UR5e.
+We have not used any public datasets such as Open-X or DROID.
+
 ## 🤗Download Pretrained Weights
 ### Download official Qwen2_VL weights
 We construct the VLM backbone by integrating Qwen2-VL-2B, a powerful and efficient model, into our framework.
@@ -108,7 +114,19 @@ python lerobot/scripts/train.py \
 --policy.device=cuda
 ~~~
 
+### Training Time
+The original DexVLA is trained on 8 x H100 GPUs. The training time for each stage is listed as follows:
+
+| Stage  | Batch Size (each GPU) | Steps  | Time (hours) |
+|--------|-----------------------|--------|--------------|
+| Stage1 | 32                    | 60000  | 30           |
+| Stage2 | 12                    | 100000 | 30           |
+| Stage3 | 12                    | 60000  | 18           |
+
+
 ## Evaluation
+### Evaluation Script
+You can evaluate DexVLA with the following script.
 ~~~shell
 python lerobot/scripts/eval.py \
 --policy.type dexvla \
@@ -121,3 +139,6 @@ python lerobot/scripts/eval.py \
 --eval.batch_size 1 \
 --device cuda
 ~~~
+
+### Inference Speed
+Tested on a single A6000 GPU, DexVLA can infer 3.4 action chunks per second. If we execute 25 actions from each chunk, the real control frequency can reach 85 Hz (3.4 * 25).
\ No newline at end of file

From 15c9ecdf52218b571fa9eef255c5bbddf25e7154 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 19 Mar 2025 05:58:57 +0000
Subject: [PATCH 35/36] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/common/policies/dexvla/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md
index 9d0b9805..b34a40bb 100644
--- a/lerobot/common/policies/dexvla/README.md
+++ b/lerobot/common/policies/dexvla/README.md
@@ -2,7 +2,7 @@ DexVLA: Vision-Language Model with Plug-In Diffusion Expert for Visuomotor Policy Learning
 
 This policy is community contributed. For more information about DexVLA, you can also refer to [the original repo](https://github.com/juruobenruo/DexVLA).
-This is the [project website](https://dex-vla.github.io/). 
+This is the [project website](https://dex-vla.github.io/).
 
 ## Dataset
 ### Data format
@@ -141,4 +141,4 @@ python lerobot/scripts/eval.py \
 ~~~
 
 ### Inference Speed
-Tested on a single A6000 GPU, DexVLA can infer 3.4 action chunks per second. If we execute 25 actions from each chunk, the real control frequency can reach 85 Hz (3.4 * 25).
\ No newline at end of file
+Tested on a single A6000 GPU, DexVLA can infer 3.4 action chunks per second. If we execute 25 actions from each chunk, the real control frequency can reach 85 Hz (3.4 * 25).

From aeddad910b4f71859b78a84ddc7ccf5f6288ec05 Mon Sep 17 00:00:00 2001
From: lesjie-wen <870351470@qq.com>
Date: Thu, 20 Mar 2025 13:32:15 +0800
Subject: [PATCH 36/36] fix dexvla requirements version mismatch

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 10305fee..f7daca2e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -85,7 +85,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
 feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
 intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
 pi0 = ["transformers>=4.48.0"]
-dexvla = ["transformers>=4.45.2", "qwen_vl_utils>=0.08", "timm==0.9.10"]
+dexvla = ["transformers>=4.45.2", "qwen_vl_utils==0.0.10", "timm==0.9.10"]
 pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
 stretch = [
     "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",