fix modeling
This commit is contained in:
parent
9f4d490423
commit
d927a90762
|
@ -19,7 +19,7 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
|
|||
action_dim=10,
|
||||
global_cond_dim=2048,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=[256, 512, 1024],
|
||||
down_dims=None,
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
state_dim=7,
|
||||
|
@ -29,6 +29,8 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
|
|||
num_train_timesteps=100,
|
||||
**kwargs,
|
||||
):
|
||||
if down_dims is None:
|
||||
down_dims = [256, 512, 1024]
|
||||
self.input_dim = action_dim
|
||||
self.noise_samples = noise_samples
|
||||
self.prediction_horizon = prediction_horizon
|
||||
|
|
|
@ -6,14 +6,9 @@ from typing import Tuple
|
|||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.functional as Func
|
||||
import torch.utils.checkpoint
|
||||
from timm.models.vision_transformer import Mlp, use_fused_attn
|
||||
from torch.jit import Final
|
||||
|
@ -51,13 +46,13 @@ class Attention(nn.Module):
|
|||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
b, n, c = x.shape
|
||||
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
x = Func.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
@ -79,7 +74,7 @@ class Attention(nn.Module):
|
|||
attn_scores += attn_mask
|
||||
|
||||
# Apply softmax to get attention weights (softmax is applied along the last dimension)
|
||||
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||
attn_weights = Func.softmax(attn_scores, dim=-1)
|
||||
|
||||
# Dropout on attention weights (if dropout is used)
|
||||
attn_weights = self.attn_drop(attn_weights)
|
||||
|
@ -87,7 +82,7 @@ class Attention(nn.Module):
|
|||
# Apply attention weights to value tensor (V)
|
||||
x = torch.matmul(attn_weights, v)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = x.transpose(1, 2).reshape(b, n, c)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
@ -162,7 +157,8 @@ class ScaleDPBlock(nn.Module):
|
|||
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
def approx_gelu():
|
||||
return nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
||||
|
||||
|
@ -213,15 +209,15 @@ class ScaleDP(PreTrainedModel):
|
|||
# compute number of tokens for main trunk and conScaleDPion encoder
|
||||
if config.n_obs_steps is None:
|
||||
config.n_obs_steps = config.prediction_horizon
|
||||
T = config.prediction_horizon
|
||||
T_cond = 1
|
||||
t = config.prediction_horizon
|
||||
t_cond = 1
|
||||
if not config.time_as_cond:
|
||||
T += 1
|
||||
T_cond -= 1
|
||||
t += 1
|
||||
t_cond -= 1
|
||||
obs_as_cond = config.cond_dim > 0
|
||||
if obs_as_cond:
|
||||
assert config.time_as_cond
|
||||
T_cond += config.n_obs_steps
|
||||
t_cond += config.n_obs_steps
|
||||
|
||||
# self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
|
||||
self.combine = nn.Sequential(
|
||||
|
@ -254,8 +250,8 @@ class ScaleDP(PreTrainedModel):
|
|||
self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
|
||||
# self.initialize_weights()
|
||||
# constants
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.t = t
|
||||
self.t_cond = t_cond
|
||||
self.prediction_horizon = config.prediction_horizon
|
||||
self.time_as_cond = config.time_as_cond
|
||||
self.action_dim = config.output_dim
|
||||
|
@ -328,7 +324,7 @@ class ScaleDP(PreTrainedModel):
|
|||
whitelist_weight_modules = (torch.nn.Linear, Attention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
for mn, m in self.named_modules():
|
||||
for pn, p in m.named_parameters():
|
||||
for pn, _p in m.named_parameters():
|
||||
fpn = "{}.{}".format(mn, pn) if mn else pn # full param name
|
||||
|
||||
if pn.endswith("bias"):
|
||||
|
@ -345,7 +341,7 @@ class ScaleDP(PreTrainedModel):
|
|||
no_decay.add(fpn)
|
||||
|
||||
# validate that we considered every parameter
|
||||
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
|
@ -360,11 +356,11 @@ class ScaleDP(PreTrainedModel):
|
|||
# create the pytorch optimizer object
|
||||
optim_groups = [
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(list(decay))],
|
||||
"params": [param_dict[pn] for pn in sorted(decay)],
|
||||
"weight_decay": weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||
"params": [param_dict[pn] for pn in sorted(no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
|
|
@ -128,7 +128,7 @@ class ConditionalUnet1D(PreTrainedModel):
|
|||
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
|
||||
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
|
||||
down_dims: Channel size for each UNet level.
|
||||
The length of this array determines numebr of levels.
|
||||
The length of this array determines number of levels.
|
||||
kernel_size: Conv kernel size
|
||||
n_groups: Number of groups for GroupNorm
|
||||
"""
|
||||
|
@ -301,7 +301,7 @@ class ConditionalUnet1D(PreTrainedModel):
|
|||
Tp = self.num_queries
|
||||
action_dim = 14
|
||||
|
||||
# initialize action from Guassian noise
|
||||
# initialize action from Gaussian noise
|
||||
noisy_action = torch.randn((B, Tp, action_dim)).cuda()
|
||||
|
||||
naction = noisy_action.to(dtype=hidden_states.dtype)
|
||||
|
|
|
@ -1937,12 +1937,12 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
|
|||
inputs_index = inputs_index.int()
|
||||
|
||||
xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
|
||||
indexs = torch.argmax((xor_array != 0).float(), dim=1)
|
||||
indexes = torch.argmax((xor_array != 0).float(), dim=1)
|
||||
input_embeddings = []
|
||||
reasoning_embeddings = []
|
||||
identity = []
|
||||
for i in range(indexs.shape[0]):
|
||||
end = indexs[i] + 1
|
||||
for i in range(indexes.shape[0]):
|
||||
end = indexes[i] + 1
|
||||
temp = input_ids[i] == 151643 # pad token id for qwen2_vl
|
||||
start = sum(temp.int())
|
||||
input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :]))
|
||||
|
|
Loading…
Reference in New Issue