cyf_fix
parent b83cb0ba89
commit e2a41716a4
@@ -7,12 +7,12 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)

 MODEL_STRUCTURE = {
-    "ScaleDP_H": {
+    "scaledp_h": {
         "depth": 32,
         "n_emb": 1280,
         "num_heads": 16,
     },
-    "ScaleDP_L": {
+    "scaledp_l": {
         "depth": 24,
         "n_emb": 1024,
         "num_heads": 16,

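Aside, not part of the commit: the two keys are now lowercase, so any lookup into MODEL_STRUCTURE has to use the new spelling. A minimal sketch of a tolerant lookup; resolve_structure and model_size are hypothetical names used only for illustration.

# Minimal sketch (not from the commit): resolving a variant after the key rename.
MODEL_STRUCTURE = {
    "scaledp_h": {"depth": 32, "n_emb": 1280, "num_heads": 16},
    "scaledp_l": {"depth": 24, "n_emb": 1024, "num_heads": 16},
}

def resolve_structure(model_size: str) -> dict:
    # Lowercasing keeps older spellings such as "ScaleDP_H" working after the rename.
    key = model_size.lower()
    if key not in MODEL_STRUCTURE:
        raise ValueError(f"Unknown ScaleDP variant: {model_size}")
    return MODEL_STRUCTURE[key]

print(resolve_structure("ScaleDP_H")["depth"])  # 32
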
@@ -8,12 +8,13 @@ import numpy as np

 import torch
 import torch.nn as nn
-import torch.nn.functional as Func
+import torch.nn.functional as func
 import torch.utils.checkpoint
 from timm.models.vision_transformer import Mlp, use_fused_attn
 from torch.jit import Final
 from transformers import AutoModel
 from transformers.modeling_utils import PreTrainedModel
+from .configuration_scaledp import ScaleDPPolicyConfig

 _logger = logging.getLogger(__name__)

@@ -52,7 +53,7 @@ class Attention(nn.Module):
         q, k = self.q_norm(q), self.k_norm(k)

         if self.fused_attn:
-            x = Func.scaled_dot_product_attention(
+            x = func.scaled_dot_product_attention(
                 q,
                 k,
                 v,
@@ -74,7 +75,7 @@ class Attention(nn.Module):
                 attn_scores += attn_mask

             # Apply softmax to get attention weights (softmax is applied along the last dimension)
-            attn_weights = Func.softmax(attn_scores, dim=-1)
+            attn_weights = func.softmax(attn_scores, dim=-1)

             # Dropout on attention weights (if dropout is used)
             attn_weights = self.attn_drop(attn_weights)
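The two Attention hunks above only rename the functional alias (Func to func); the code they touch is the usual fused-vs-manual attention split. A self-contained sketch of that split for reference; shapes are illustrative and the tolerance is only meant for the CPU math backend.

import math
import torch
import torch.nn.functional as func

def attention(q, k, v, attn_mask=None, fused=True):
    # q, k, v: (batch, heads, seq, head_dim)
    if fused:
        # Single fused kernel; handles scaling, masking and softmax internally.
        return func.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # Manual path: scores -> optional mask -> softmax -> weighted sum of values.
    attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        attn_scores += attn_mask
    attn_weights = func.softmax(attn_scores, dim=-1)
    return attn_weights @ v

q = k = v = torch.randn(2, 16, 8, 64)
assert torch.allclose(attention(q, k, v, fused=True), attention(q, k, v, fused=False), atol=1e-5)
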
@@ -191,7 +192,6 @@ class FinalLayer(nn.Module):
         return x


-from .configuration_scaledp import ScaleDPPolicyConfig


 class ScaleDP(PreTrainedModel):
@@ -379,23 +379,23 @@ class ScaleDP(PreTrainedModel):
     def forward(self, actions, hidden_states, states, is_pad):
         """
         Forward pass for the diffusion head.
-        :param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1
-        :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [B,Tokens, D] 8 1200 1024
-        :param states: robot states, shape [B, D]
+        :param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
+        :param hidden_states: hidden states from the llava_pythia, as the conScaleDPion for the diffusion, shape [b,Tokens, D] 8 1200 1024
+        :param states: robot states, shape [b, D]
         :return: loss
         """
         if actions is not None:  # training time
-            B = actions.size(0)
+            b = actions.size(0)
             actions = actions[:, : self.num_queries]
             is_pad = is_pad[:, : self.num_queries]
             num_noise_samples = self.noise_samples
             # sample noise to add to actions
             noise = torch.randn(
                 [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
-            ) # num_noise, B, Ta, D(1, 2, 16, 14)
+            ) # num_noise, b, Ta, D(1, 2, 16, 14)
             # sample a diffusion iteration for each data point
             timesteps = torch.randint(
-                0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device
+                0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
             ).long()

             timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
@@ -405,7 +405,7 @@ class ScaleDP(PreTrainedModel):
             noisy_actions = torch.cat(
                 [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
                 dim=0,
-            ) # [num_noise_samples * B, Ta, action_dim]
+            ) # [num_noise_samples * b, Ta, action_dim]

             noisy_actions = noisy_actions.to(dtype=actions.dtype)
             assert hidden_states.ndim == 3
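For context, not part of the diff: the training branch draws num_noise_samples independent noise tensors, picks one diffusion timestep per batch element, and applies the scheduler's add_noise once per noise sample. A standalone sketch of that pattern, assuming a diffusers-style DDPMScheduler and illustrative shapes.

import torch
from diffusers import DDPMScheduler

b, ta, action_dim = 2, 16, 14      # batch, action-chunk length, action dimension
num_noise_samples = 8

scheduler = DDPMScheduler(num_train_timesteps=100)
actions = torch.randn(b, ta, action_dim)

# One noise tensor per sample: (num_noise_samples, b, Ta, D)
noise = torch.randn([num_noise_samples] + list(actions.shape))
# One diffusion step per batch element, shared across noise samples: (b,)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (b,)).long()

# add_noise broadcasts over the leading batch dim, so apply it once per noise sample
noisy_actions = torch.cat(
    [scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
    dim=0,
)  # (num_noise_samples * b, Ta, D)
assert noisy_actions.shape == (num_noise_samples * b, ta, action_dim)
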
@@ -425,12 +425,12 @@ class ScaleDP(PreTrainedModel):
             return {"loss": loss}
             # return loss
         else:  # inference time
-            B = 1
-            Tp = self.num_queries
+            b = 1
+            tp = self.num_queries
             action_dim = self.action_dim

             # initialize action from Gaussian noise
-            noisy_action = torch.randn((B, Tp, action_dim)).cuda()
+            noisy_action = torch.randn((b, tp, action_dim)).cuda()

             naction = noisy_action.to(dtype=hidden_states.dtype)
             # init scheduler
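The inference branch initialised above continues, outside this hunk, with an iterative denoising loop. A minimal sketch of such a loop, assuming a diffusers-style scheduler and a hypothetical model(sample, timestep, global_cond) noise predictor; it is an illustration, not the file's actual sampling code.

import torch
from diffusers import DDPMScheduler

@torch.no_grad()
def sample_actions(model, global_cond, num_queries=16, action_dim=14, num_inference_steps=10):
    scheduler = DDPMScheduler(num_train_timesteps=100)
    scheduler.set_timesteps(num_inference_steps)

    # start from pure Gaussian noise: (b, tp, D), with b = 1 at inference time
    naction = torch.randn((1, num_queries, action_dim))

    for t in scheduler.timesteps:
        noise_pred = model(naction, t, global_cond)                   # predict the added noise
        naction = scheduler.step(noise_pred, t, naction).prev_sample  # one reverse diffusion step
    return naction
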
@@ -540,11 +540,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 #################################################################################


-def ScaleDP_H(**kwargs):
+def scaledp_h(**kwargs):
     return ScaleDP(depth=32, n_emb=1280, num_heads=16, **kwargs)


-def ScaleDP_L(**kwargs):
+def scaledp_l(**kwargs):
     return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)

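The hunk header references get_1d_sincos_pos_embed_from_grid. For readers unfamiliar with it, this is the standard 1D sine/cosine position embedding used by MAE/DiT-style models; the sketch below is a generic re-derivation in NumPy, not a quote from the patched file.

import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # Standard sincos embedding: half the channels carry sin, half carry cos,
    # with geometrically spaced frequencies (generic re-derivation).
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega              # (D/2,)
    pos = pos.reshape(-1)                   # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2) outer product of positions and frequencies
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

emb = get_1d_sincos_pos_embed_from_grid(1280, np.arange(16))
assert emb.shape == (16, 1280)
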
@@ -251,23 +251,23 @@ class ConditionalUnet1D(PreTrainedModel):
     def forward(self, actions, hidden_states, states, is_pad):
         """
         Forward pass for the diffusion head.
-        :param actions: target actions, shape [B, Ta, D] D:10 = 3+6+1
-        :param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [B,Tokens, D] 8 1200 1024
-        :param states: robot states, shape [B, D]
+        :param actions: target actions, shape [b, Ta, D] D:10 = 3+6+1
+        :param hidden_states: hidden states from the llava_pythia, as the condition for the diffusion, shape [b,Tokens, D] 8 1200 1024
+        :param states: robot states, shape [b, D]
         :return: loss
         """
         if actions is not None:  # training time
-            B = actions.size(0)
+            b = actions.size(0)
             actions = copy.deepcopy(actions[:, : self.num_queries])
             is_pad = copy.deepcopy(is_pad[:, : self.num_queries])
             num_noise_samples = self.noise_samples
             # sample noise to add to actions
             noise = torch.randn(
                 [num_noise_samples] + list(actions.shape), device=actions.device, dtype=actions.dtype
-            ) # num_noise, B, Ta, D
+            ) # num_noise, b, Ta, D
             # sample a diffusion iteration for each data point
             timesteps = torch.randint(
-                0, self.noise_scheduler.config.num_train_timesteps, (B,), device=actions.device
+                0, self.noise_scheduler.config.num_train_timesteps, (b,), device=actions.device
             ).long()

             timesteps, noise = timesteps.to(actions.device), noise.to(actions.device)
@@ -277,7 +277,7 @@ class ConditionalUnet1D(PreTrainedModel):
             noisy_actions = torch.cat(
                 [self.noise_scheduler.add_noise(actions, noise[i], timesteps) for i in range(len(noise))],
                 dim=0,
-            ) # [num_noise_samples * B, Ta, action_dim]
+            ) # [num_noise_samples * b, Ta, action_dim]

             noisy_actions = noisy_actions.to(dtype=actions.dtype)
             assert hidden_states.ndim == 3
@@ -297,12 +297,12 @@ class ConditionalUnet1D(PreTrainedModel):
             return {"loss": loss}
             # return loss
         else:  # inference time
-            B = 1
-            Tp = self.num_queries
+            b = 1
+            tp = self.num_queries
             action_dim = 14

             # initialize action from Gaussian noise
-            noisy_action = torch.randn((B, Tp, action_dim)).cuda()
+            noisy_action = torch.randn((b, tp, action_dim)).cuda()

             naction = noisy_action.to(dtype=hidden_states.dtype)
             # init scheduler
@@ -323,14 +323,14 @@ class ConditionalUnet1D(PreTrainedModel):
         self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], global_cond=None, states=None
     ):
         """
-        x: (B,T,input_dim)
-        timestep: (B,) or int, diffusion step
-        global_cond: (B,global_cond_dim)
-        output: (B,T,input_dim)
+        x: (b,T,input_dim)
+        timestep: (b,) or int, diffusion step
+        global_cond: (b,global_cond_dim)
+        output: (b,T,input_dim)
         """
-        # (B,T,C)
+        # (b,t,c)
         sample = sample.moveaxis(-1, -2)
-        # (B,C,T)
+        # (b,c,t)
         # global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
         global_cond = global_cond.squeeze(1)

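Context for the moveaxis lines above: the U-Net operates on channel-first (b, C, T) tensors for Conv1d, while callers pass channel-last (b, T, C). A small sketch of the round trip with illustrative sizes.

import torch
import torch.nn as nn

b, t, c = 2, 16, 14
sample = torch.randn(b, t, c)            # (b, T, C) as produced by the policy

x = sample.moveaxis(-1, -2)              # (b, C, T), the layout nn.Conv1d expects
x = nn.Conv1d(c, 32, kernel_size=3, padding=1)(x)   # (b, 32, T)
x = x.moveaxis(-1, -2)                   # back to (b, T, 32) for the caller
assert x.shape == (b, t, 32)
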
@@ -353,7 +353,7 @@ class ConditionalUnet1D(PreTrainedModel):

         x = sample
         h = []
-        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
+        for _idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
             x = resnet(x, global_feature)
             x = resnet2(x, global_feature)
             h.append(x)
@@ -362,7 +362,7 @@ class ConditionalUnet1D(PreTrainedModel):
         for mid_module in self.mid_modules:
             x = mid_module(x, global_feature)

-        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
+        for _idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
             x = torch.cat((x, h.pop()), dim=1)
             x = resnet(x, global_feature)
             x = resnet2(x, global_feature)
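The two loops above only rename the unused loop index to _idx; the surrounding pattern is the usual U-Net skip-connection scheme (push activations on the way down, pop and concatenate on the way up). A toy sketch of that control flow with placeholder Conv1d blocks, not the real conditional residual blocks.

import torch
import torch.nn as nn

class TinyUnet1D(nn.Module):
    """Toy illustration of the down/mid/up skip-connection pattern."""
    def __init__(self, dims=(14, 32, 64)):
        super().__init__()
        self.down = nn.ModuleList(
            [nn.Conv1d(cin, cout, 3, padding=1) for cin, cout in zip(dims[:-1], dims[1:])]
        )
        self.mid = nn.Conv1d(dims[-1], dims[-1], 3, padding=1)
        # up blocks take doubled channels because of the torch.cat below
        self.up = nn.ModuleList(
            [nn.Conv1d(2 * cout, cin, 3, padding=1) for cin, cout in zip(dims[:-1][::-1], dims[1:][::-1])]
        )

    def forward(self, x):                    # x: (b, C, T)
        h = []
        for down in self.down:
            x = down(x)
            h.append(x)                      # remember each stage for the decoder
        x = self.mid(x)
        for up in self.up:
            x = torch.cat((x, h.pop()), dim=1)   # skip connection along channels
            x = up(x)
        return x                             # (b, C, T), same channel count as the input

out = TinyUnet1D()(torch.randn(2, 14, 16))
assert out.shape == (2, 14, 16)
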
@@ -370,9 +370,9 @@ class ConditionalUnet1D(PreTrainedModel):

         x = self.final_conv(x)

-        # (B,C,T)
+        # (b,c,t)
         x = x.moveaxis(-1, -2)
-        # (B,T,C)
+        # (b,t,c)
         return x

@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as func
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss, LayerNorm
 from transformers import AutoConfig, AutoModel
@@ -48,7 +48,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )

-from lerobot.common.policies.dexvla.fusion_modules import *
+from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM

 from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig

@@ -407,7 +407,7 @@ class VisionSdpaAttention(nn.Module):
         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
         v = v.transpose(0, 1)
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+        attn_output = func.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
         attn_output = attn_output.transpose(0, 1)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
@@ -879,7 +879,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-        is_causal = True if causal_mask is None and q_len > 1 else False
+        is_causal = bool(causal_mask is None and q_len > 1)

         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
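Note on the rewrite above: bool(cond) and the ternary are equivalent here because the condition already evaluates to a Python bool, and the resulting flag simply enables SDPA's built-in causal masking when no explicit mask is supplied. A tiny illustration with made-up shapes.

import torch
import torch.nn.functional as func

q = k = v = torch.randn(1, 2, 4, 8)   # (batch, heads, q_len, head_dim)
causal_mask, q_len = None, q.size(2)

# Equivalent ways to compute the dispatch flag; the comparison already returns a bool.
is_causal = bool(causal_mask is None and q_len > 1)
assert is_causal == (True if causal_mask is None and q_len > 1 else False)

# With is_causal=True and no explicit mask, SDPA applies the causal mask internally.
out = func.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=is_causal)
assert out.shape == q.shape
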
@@ -1102,7 +1102,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         cu_seqlens = torch.repeat_interleave(
             grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        cu_seqlens = func.pad(cu_seqlens, (1, 0), value=0)

         for blk in self.blocks:
             hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
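Context for the lines above (the change itself is only the F-to-func alias): cu_seqlens turns per-frame patch counts into cumulative offsets, and the padded leading zero makes consecutive entries delimit each sequence. A standalone sketch with a made-up grid.

import torch
import torch.nn.functional as func

# One row per image/frame: (t, h, w) patch-grid sizes. Values are illustrative.
grid_spatiotemporal = torch.tensor([[1, 4, 6], [2, 2, 2]])

# tokens per frame = h * w, repeated t times, then turned into cumulative offsets
cu_seqlens = torch.repeat_interleave(
    grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = func.pad(cu_seqlens, (1, 0), value=0)  # prepend 0 so sequence i spans [cu[i], cu[i+1])

print(cu_seqlens)  # tensor([ 0, 24, 28, 32], dtype=torch.int32)
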
@@ -1164,12 +1164,11 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1281,15 +1280,15 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
             self.config._attn_implementation == "sdpa"
             and not (using_static_cache or using_sliding_window_cache)
             and not output_attentions
-        ):
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            and AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
                 past_key_values_length=past_seen_tokens,
                 sliding_window=self.config.sliding_window,
                 is_training=self.training,
-            ):
-                return None
+            )
+        ):
+            return None

         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
@@ -1377,14 +1376,13 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
             )
             diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            if config.sliding_window is not None:
+            if config.sliding_window is not None and (not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length):
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
-                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
-                        cache_position.reshape(-1, 1) - config.sliding_window
-                    )
-                    diagonal_attend_mask |= sliding_attend_mask
+                sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    cache_position.reshape(-1, 1) - config.sliding_window
+                )
+                diagonal_attend_mask |= sliding_attend_mask
             causal_mask *= diagonal_attend_mask
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
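Context for the hunk above (the change only flattens the nested if): the mask combines a causal constraint with a sliding-window constraint, both built from torch.arange comparisons against cache_position. A small standalone sketch with toy sizes; values are illustrative.

import torch

sequence_length, target_length, sliding_window = 4, 6, 3
min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(2, 2 + sequence_length)   # positions of the new query tokens

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype)
# True where a key position lies in the query's future (stays masked)
diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# True where a key position has fallen out of the sliding window (also masked)
sliding_attend_mask = torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
diagonal_attend_mask |= sliding_attend_mask

causal_mask *= diagonal_attend_mask        # keep min_dtype only where attention is forbidden
print((causal_mask == 0).int())            # 1 marks positions each query token may attend to
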