wk 2025-03-11 14:31:12 +08:00
parent b83cb0ba89
commit e2a41716a4
4 changed files with 57 additions and 59 deletions

View File

@@ -7,12 +7,12 @@ from transformers.utils import logging
logger = logging.get_logger(__name__)
MODEL_STRUCTURE = {
"ScaleDP_H": {
"scaledp_h": {
"depth": 32,
"n_emb": 1280,
"num_heads": 16,
},
"ScaleDP_L": {
"scaledp_l": {
"depth": 24,
"n_emb": 1024,
"num_heads": 16,

View File

@ -8,12 +8,13 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Func
import torch.nn.functional as func
import torch.utils.checkpoint
from timm.models.vision_transformer import Mlp, use_fused_attn
from torch.jit import Final
from transformers import AutoModel
from transformers.modeling_utils import PreTrainedModel
from .configuration_scaledp import ScaleDPPolicyConfig
_logger = logging.getLogger(__name__)
@@ -52,7 +53,7 @@ class Attention(nn.Module):
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = Func.scaled_dot_product_attention(
x = func.scaled_dot_product_attention(
q,
k,
v,
@@ -74,7 +75,7 @@ class Attention(nn.Module):
attn_scores += attn_mask
# Apply softmax to get attention weights (softmax is applied along the last dimension)
attn_weights = Func.softmax(attn_scores, dim=-1)
attn_weights = func.softmax(attn_scores, dim=-1)
# Dropout on attention weights (if dropout is used)
attn_weights = self.attn_drop(attn_weights)
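Both branches above compute the same attention: the fused kernel via scaled_dot_product_attention, and the manual path via scaled scores, softmax, and dropout. A self-contained sketch of that equivalence (shapes are arbitrary and dropout is disabled so the two outputs match):

import torch
import torch.nn.functional as func

# Arbitrary shapes: batch 2, 4 heads, 16 tokens, head dim 8.
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))

# Fused kernel path (the `fused_attn` branch).
fused = func.scaled_dot_product_attention(q, k, v)

# Manual path: scaled scores, softmax over the key dimension, weighted sum of values.
scores = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
weights = func.softmax(scores, dim=-1)
manual = weights @ v

assert torch.allclose(fused, manual, atol=1e-5)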
@@ -191,7 +192,6 @@ class FinalLayer(nn.Module):
return x
from .configuration_scaledp import ScaleDPPolicyConfig
class ScaleDP(PreTrainedModel):
@@ -379,23 +379,23 @@ class ScaleDP(PreTrainedModel):
def forward(self, actions, hidden_states, states, is_pad):
"""
Forward pass for the diffusion head.
:param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [B,Tokens, D] 8 1200 1024
:param states: robot states, shape [B, D]
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024
:param states: robot states, shape [b, D]
:return: loss
"""
if actions is not None: # training time
B = actions.size(0)
b = actions.size(0)
actions = actions[:, : self.num_queries]
is_pad = is_pad[:, : self.num_queries]
num_noise_samples = self.noise_samples
# sample noise to add to actions
noise = torch.randn(
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
) # num_noise, B, Ta, D(1, 2, 16, 14)
) # num_noise, b, Ta, D(1, 2, 16, 14)
# sample a diffusion iteration for each data point
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
).long()
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
@@ -405,7 +405,7 @@ class ScaleDP(PreTrainedModel):
noisy_actions = torch.cat(
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
dim=0,
) # [num_noise_samples * B, Ta, action_dim]
) # [num_noise_samples * b, Ta, action_dim]
noisy_actions = noisy_actions.to(dtype=actions.dtype)
assert hidden_states.ndim == 3
@@ -425,12 +425,12 @@ class ScaleDP(PreTrainedModel):
return {"loss": loss}
# return loss
else: # inference time
B = 1
Tp = self.num_queries
b = 1
tp = self.num_queries
action_dim = self.action_dim
# initialize action from Gaussian noise
noisy_action = torch.randn((B, Tp, action_dim)).cuda()
noisy_action = torch.randn((b, tp, action_dim)).cuda()
naction = noisy_action.to(dtype=hidden_states.dtype)
# init scheduler
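For orientation, the training branch above adds scheduler noise to ground-truth action chunks at random timesteps and regresses that noise, while the inference branch starts from pure Gaussian noise and denoises it step by step. A minimal sketch of this pattern with a diffusers DDPMScheduler; `denoise_net` and the shapes are placeholders, not this repository's model:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=100)
denoise_net = lambda noisy, t, cond: torch.zeros_like(noisy)  # stand-in for the ScaleDP transformer

def training_step(actions, cond):
    # actions: (b, Ta, D) ground-truth action chunk; cond: conditioning features
    b = actions.size(0)
    noise = torch.randn_like(actions)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (b,), device=actions.device)
    noisy_actions = scheduler.add_noise(actions, noise, timesteps)  # forward diffusion
    pred = denoise_net(noisy_actions, timesteps, cond)              # predict the injected noise
    return torch.nn.functional.mse_loss(pred, noise)

@torch.no_grad()
def sample_actions(cond, tp=16, action_dim=10, num_inference_steps=10):
    naction = torch.randn(1, tp, action_dim)     # start from Gaussian noise
    scheduler.set_timesteps(num_inference_steps)
    for t in scheduler.timesteps:
        pred = denoise_net(naction, t, cond)
        naction = scheduler.step(pred, t, naction).prev_sample  # one reverse-diffusion step
    return naction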
@@ -540,11 +540,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
#################################################################################
def ScaleDP_H(**kwargs):
def scaledp_h(**kwargs):
return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs)
def ScaleDP_L(**kwargs):
def scaledp_l(**kwargs):
return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)

View File

@@ -251,23 +251,23 @@ class ConditionalUnet1D(PreTrainedModel):
def forward(self, actions, hidden_states, states, is_pad):
"""
Forward pass for the diffusion head.
:param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [B,Tokens, D] 8 1200 1024
:param states: robot states, shape [B, D]
:param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
:param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024
:param states: robot states, shape [b, D]
:return: loss
"""
if actions is not None: # training time
B = actions.size(0)
b = actions.size(0)
actions = copy.deepcopy(actions[:, : self.num_queries])
is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
num_noise_samples = self.noise_samples
# sample noise to add to actions
noise = torch.randn(
[num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
) # num_noise, B, Ta, D
) # num_noise, b, Ta, D
# sample a diffusion iteration for each data point
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device
0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
).long()
timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
@@ -277,7 +277,7 @@ class ConditionalUnet1D(PreTrainedModel):
noisy_actions = torch.cat(
[self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
dim=0,
) # [num_noise_samples * B, Ta, action_dim]
) # [num_noise_samples * b, Ta, action_dim]
noisy_actions = noisy_actions.to(dtype=actions.dtype)
assert hidden_states.ndim == 3
@@ -297,12 +297,12 @@ class ConditionalUnet1D(PreTrainedModel):
return {"loss": loss}
# return loss
else: # inference time
B = 1
Tp = self.num_queries
b = 1
tp = self.num_queries
action_dim = 14
# initialize action from Gaussian noise
noisy_action = torch.randn((B, Tp, action_dim)).cuda()
noisy_action = torch.randn((b, tp, action_dim)).cuda()
naction = noisy_action.to(dtype=hidden_states.dtype)
# init scheduler
@@ -323,14 +323,14 @@ class ConditionalUnet1D(PreTrainedModel):
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
x: (b,T,input_dim)
timestep: (b,) or int, diffusion step
global_cond: (b,global_cond_dim)
output: (b,T,input_dim)
"""
# (B,T,C)
# (b,t,c)
sample = sample.moveaxis(-1, -2)
# (B,C,T)
# (b,c,t)
# global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
global_cond = global_cond.squeeze(1)
@@ -353,7 +353,7 @@ class ConditionalUnet1D(PreTrainedModel):
x = sample
h = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
h.append(x)
@@ -362,7 +362,7 @@ class ConditionalUnet1D(PreTrainedModel):
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
@@ -370,9 +370,9 @@ class ConditionalUnet1D(PreTrainedModel):
x = self.final_conv(x)
# (B,C,T)
# (b,c,t)
x = x.moveaxis(-1, -2)
# (B,T,C)
# (b,t,c)
return x
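The forward pass above follows the usual 1D U-Net pattern: transpose to channels-first for Conv1d, push each down-block output onto a stack, and pop it back for concatenation in the matching up-block. A stripped-down sketch of that skip-connection bookkeeping (layer sizes are arbitrary and the down/upsampling steps are omitted):

import torch
import torch.nn as nn

class TinyUnet1D(nn.Module):
    def __init__(self, in_dim=10, dims=(64, 128)):
        super().__init__()
        d0, d1 = dims
        self.down = nn.ModuleList([nn.Conv1d(in_dim, d0, 3, padding=1), nn.Conv1d(d0, d1, 3, padding=1)])
        self.mid = nn.Conv1d(d1, d1, 3, padding=1)
        # Up blocks consume the current path concatenated with the popped skip tensor.
        self.up = nn.ModuleList([nn.Conv1d(d1 + d1, d0, 3, padding=1), nn.Conv1d(d0 + d0, in_dim, 3, padding=1)])

    def forward(self, x):              # x: (b, T, C)
        x = x.moveaxis(-1, -2)         # (b, C, T) for Conv1d
        h = []
        for blk in self.down:
            x = blk(x)
            h.append(x)                # remember each level for its skip connection
        x = self.mid(x)
        for blk in self.up:
            x = blk(torch.cat((x, h.pop()), dim=1))  # LIFO pairing: deepest skip first
        return x.moveaxis(-1, -2)      # back to (b, T, C)

out = TinyUnet1D()(torch.randn(2, 16, 10))  # -> shape (2, 16, 10)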

View File

@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as func
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers import AutoConfig, AutoModel
@@ -48,7 +48,7 @@ from transformers.utils import (
replace_return_docstrings,
)
from lerobot.common.policies.dexvla.fusion_modules import *
from lerobot.common.policies.dexvla.fusion_modules import ActionProjector, FiLM
from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
@@ -407,7 +407,7 @@ class VisionSdpaAttention(nn.Module):
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = func.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
@@ -879,7 +879,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if causal_mask is None and q_len > 1 else False
is_causal = bool(causal_mask is None and q_len > 1)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
@@ -1102,7 +1102,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
cu_seqlens = torch.repeat_interleave(
grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
cu_seqlens = func.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
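The cu_seqlens computation above turns the per-image spatiotemporal grids into cumulative sequence boundaries, one segment per temporal frame, for variable-length attention. A small worked example with made-up grid sizes:

import torch
import torch.nn.functional as func

# Hypothetical grids: image 1 is 1 frame of 4x6 patches, image 2 is 2 frames of 2x3 patches.
grid_spatiotemporal = torch.tensor([[1, 4, 6], [2, 2, 3]])  # (t, h, w) per image

tokens_per_frame = grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2]          # [24, 6]
per_frame = torch.repeat_interleave(tokens_per_frame, grid_spatiotemporal[:, 0])  # [24, 6, 6]
cu_seqlens = per_frame.cumsum(dim=0, dtype=torch.int32)                           # [24, 30, 36]
cu_seqlens = func.pad(cu_seqlens, (1, 0), value=0)                                # [0, 24, 30, 36]
# Attention is then restricted to tokens within each [cu_seqlens[i], cu_seqlens[i+1]) window.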
@@ -1164,12 +1164,11 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -1281,15 +1280,15 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
and AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
)
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
@@ -1377,14 +1376,13 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.sliding_window is not None:
if config.sliding_window is not None and (not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length):
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify whether the current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.sliding_window
)
diagonal_attend_mask |= sliding_attend_mask
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
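To make the refactored masking above concrete: the diagonal mask blocks future tokens, the sliding-window mask blocks tokens farther back than the window, and their union is folded into the additive causal mask. A small standalone example with arbitrary sizes:

import torch

target_length, sliding_window = 6, 3
cache_position = torch.arange(target_length)  # one query per position, no KV-cache offset

cols = torch.arange(target_length)
diagonal_attend_mask = cols > cache_position.reshape(-1, 1)                      # True = future token
sliding_attend_mask = cols <= (cache_position.reshape(-1, 1) - sliding_window)   # True = beyond the window
blocked = diagonal_attend_mask | sliding_attend_mask

print((~blocked).int())  # each row i may attend to columns i-2 .. i (a causal window of width 3)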