fix modeling

wk 2025-03-11 14:02:30 +08:00
parent 9f4d490423
commit d927a90762
4 changed files with 27 additions and 29 deletions

View File

@@ -19,7 +19,7 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         action_dim=10,
         global_cond_dim=2048,
         diffusion_step_embed_dim=256,
-        down_dims=[256, 512, 1024],
+        down_dims=None,
         kernel_size=5,
         n_groups=8,
         state_dim=7,
@@ -29,6 +29,8 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         num_train_timesteps=100,
         **kwargs,
     ):
+        if down_dims is None:
+            down_dims = [256, 512, 1024]
         self.input_dim = action_dim
         self.noise_samples = noise_samples
         self.prediction_horizon = prediction_horizon
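Note on this change: a mutable default like `down_dims=[256, 512, 1024]` is evaluated once at definition time and shared by every call, so mutations leak across calls. A minimal sketch of the pitfall (function names are illustrative, not from this repo):

```python
def broken(dims=[256, 512, 1024]):
    dims.append(2048)      # mutates the single shared default list
    return dims

def fixed(dims=None):
    if dims is None:       # a fresh list is built on every call
        dims = [256, 512, 1024]
    dims.append(2048)
    return dims

print(broken())  # [256, 512, 1024, 2048]
print(broken())  # [256, 512, 1024, 2048, 2048] -- state leaked across calls
print(fixed())   # [256, 512, 1024, 2048]
print(fixed())   # [256, 512, 1024, 2048]
```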

View File

@@ -6,14 +6,9 @@ from typing import Tuple
 import numpy as np
-try:
-    from typing import Literal
-except ImportError:
-    pass
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as Func
 import torch.utils.checkpoint
 from timm.models.vision_transformer import Mlp, use_fused_attn
 from torch.jit import Final
@@ -51,13 +46,13 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        b, n, c = x.shape
+        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(
+            x = Func.scaled_dot_product_attention(
                 q,
                 k,
                 v,
@@ -79,7 +74,7 @@ class Attention(nn.Module):
                 attn_scores += attn_mask

             # Apply softmax to get attention weights (softmax is applied along the last dimension)
-            attn_weights = F.softmax(attn_scores, dim=-1)
+            attn_weights = Func.softmax(attn_scores, dim=-1)

             # Dropout on attention weights (if dropout is used)
             attn_weights = self.attn_drop(attn_weights)
@@ -87,7 +82,7 @@ class Attention(nn.Module):
             # Apply attention weights to value tensor (V)
             x = torch.matmul(attn_weights, v)

-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(b, n, c)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
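For reference, the fused and manual paths in this forward compute the same result; a small self-contained check (shapes are illustrative):

```python
import math
import torch
import torch.nn.functional as Func

q = torch.randn(2, 8, 16, 64)  # (batch, heads, tokens, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

fused = Func.scaled_dot_product_attention(q, k, v)

scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
manual = Func.softmax(scores, dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))  # True (up to float tolerance)
```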
@@ -162,7 +157,8 @@ class ScaleDPBlock(nn.Module):
         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        approx_gelu = lambda: nn.GELU(approximate="tanh")
+        def approx_gelu():
+            return nn.GELU(approximate="tanh")

         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
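The `adaLN_modulation` head above follows the DiT-style pattern: one conditioning vector is projected to six chunks that shift, scale, and gate the attention and MLP branches. A rough sketch of how such a head is consumed (dimensions and variable names are illustrative, not this repo's exact code):

```python
import torch
import torch.nn as nn

hidden_size = 256
ada_ln = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

cond = torch.randn(4, hidden_size)  # e.g. a timestep embedding
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_ln(cond).chunk(6, dim=-1)

x = torch.randn(4, 10, hidden_size)
norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# modulate: x_norm * (1 + scale) + shift, broadcast over the token dimension
x_mod = norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
# gate_msa would then scale the attention branch's residual contribution
```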
@@ -213,15 +209,15 @@ class ScaleDP(PreTrainedModel):
         # compute number of tokens for main trunk and condition encoder
         if config.n_obs_steps is None:
             config.n_obs_steps = config.prediction_horizon

-        T = config.prediction_horizon
-        T_cond = 1
+        t = config.prediction_horizon
+        t_cond = 1
         if not config.time_as_cond:
-            T += 1
-            T_cond -= 1
+            t += 1
+            t_cond -= 1
         obs_as_cond = config.cond_dim > 0
         if obs_as_cond:
             assert config.time_as_cond
-            T_cond += config.n_obs_steps
+            t_cond += config.n_obs_steps

         # self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
         self.combine = nn.Sequential(
@@ -254,8 +250,8 @@ class ScaleDP(PreTrainedModel):
         self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
         # self.initialize_weights()
         # constants
-        self.T = T
-        self.T_cond = T_cond
+        self.t = t
+        self.t_cond = t_cond
         self.prediction_horizon = config.prediction_horizon
         self.time_as_cond = config.time_as_cond
         self.action_dim = config.output_dim
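A worked example of the token bookkeeping above, with illustrative config values:

```python
# Illustrative values; cond_dim > 0 means observations act as condition tokens
prediction_horizon, n_obs_steps, time_as_cond, cond_dim = 16, 2, True, 512

t = prediction_horizon  # tokens processed by the main trunk
t_cond = 1              # one token for the diffusion timestep
if not time_as_cond:
    t += 1              # timestep token joins the trunk instead
    t_cond -= 1
if cond_dim > 0:
    t_cond += n_obs_steps  # one condition token per observation step

print(t, t_cond)  # 16 3
```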
@@ -328,7 +324,7 @@ class ScaleDP(PreTrainedModel):
         whitelist_weight_modules = (torch.nn.Linear, Attention)
         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
-            for pn, p in m.named_parameters():
+            for pn, _p in m.named_parameters():
                 fpn = "{}.{}".format(mn, pn) if mn else pn  # full param name

                 if pn.endswith("bias"):
@@ -345,7 +341,7 @@ class ScaleDP(PreTrainedModel):
                     no_decay.add(fpn)

         # validate that we considered every parameter
-        param_dict = {pn: p for pn, p in self.named_parameters()}
+        param_dict = dict(self.named_parameters())
         inter_params = decay & no_decay
         union_params = decay | no_decay
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
@@ -360,11 +356,11 @@ class ScaleDP(PreTrainedModel):
         # create the pytorch optimizer object
         optim_groups = [
             {
-                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "params": [param_dict[pn] for pn in sorted(decay)],
                 "weight_decay": weight_decay,
             },
             {
-                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "params": [param_dict[pn] for pn in sorted(no_decay)],
                 "weight_decay": 0.0,
             },
         ]
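A condensed sketch of the grouping pattern this hunk touches, using parameter dimensionality as a simplified stand-in for the module-type whitelist/blacklist walk above (hyperparameters are illustrative):

```python
import torch

# Simplified rule: 1-D parameters (biases, norm scales) skip weight decay,
# 2-D+ parameters (linear/conv weights) receive it.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
decay = {n for n, p in model.named_parameters() if p.ndim >= 2}
no_decay = {n for n, p in model.named_parameters() if p.ndim < 2}

param_dict = dict(model.named_parameters())
optim_groups = [
    {"params": [param_dict[n] for n in sorted(decay)], "weight_decay": 0.1},
    {"params": [param_dict[n] for n in sorted(no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=1e-4)
```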

View File

@@ -128,7 +128,7 @@ class ConditionalUnet1D(PreTrainedModel):
             in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
         diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
         down_dims: Channel size for each UNet level.
-            The length of this array determines numebr of levels.
+            The length of this array determines number of levels.
         kernel_size: Conv kernel size
         n_groups: Number of groups for GroupNorm
         """
@@ -301,7 +301,7 @@ class ConditionalUnet1D(PreTrainedModel):
         Tp = self.num_queries
         action_dim = 14

-        # initialize action from Guassian noise
+        # initialize action from Gaussian noise
         noisy_action = torch.randn((B, Tp, action_dim)).cuda()
         naction = noisy_action.to(dtype=hidden_states.dtype)
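A minimal sketch of how an action initialized this way is consumed: it starts as pure Gaussian noise and is iteratively refined. The update rule below is a placeholder for the real scheduler step, and `policy` is a stand-in for the UNet:

```python
import torch
import torch.nn as nn

# Toy denoising loop (illustrative): real code conditions the policy on a
# timestep embedding and updates naction via scheduler.step(...).
batch, horizon, action_dim = 2, 16, 14
policy = nn.Linear(action_dim, action_dim)  # stand-in for the UNet

naction = torch.randn(batch, horizon, action_dim)  # "noisy_action" above
for k in reversed(range(10)):
    noise_pred = policy(naction)
    naction = naction - 0.1 * noise_pred  # placeholder update, not DDPM math
```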

View File

@@ -1937,12 +1937,12 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
         inputs_index = inputs_index.int()

         xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
-        indexs = torch.argmax((xor_array != 0).float(), dim=1)
+        indexes = torch.argmax((xor_array != 0).float(), dim=1)
         input_embeddings = []
         reasoning_embeddings = []
         identity = []
-        for i in range(indexs.shape[0]):
-            end = indexs[i] + 1
+        for i in range(indexes.shape[0]):
+            end = indexes[i] + 1
             temp = input_ids[i] == 151643  # pad token id for qwen2_vl
             start = sum(temp.int())
             input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :]))
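How the XOR trick in this hunk finds the boundary: adjacent elements of the binary mask differ exactly where it flips, and `argmax` returns the first such position. A small self-contained check (the mask values are illustrative):

```python
import torch

inputs_index = torch.tensor([[1, 1, 1, 0, 0, 0]])  # illustrative binary mask
xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
indexes = torch.argmax((xor_array != 0).float(), dim=1)

print(xor_array)  # tensor([[0, 0, 1, 0, 0]])
print(indexes)    # tensor([2]) -> end of the leading block; end = indexes + 1
```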