fix modeling
parent 9f4d490423
commit d927a90762
@@ -19,7 +19,7 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         action_dim=10,
         global_cond_dim=2048,
         diffusion_step_embed_dim=256,
-        down_dims=[256, 512, 1024],
+        down_dims=None,
         kernel_size=5,
         n_groups=8,
         state_dim=7,
@@ -29,6 +29,8 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         num_train_timesteps=100,
         **kwargs,
     ):
+        if down_dims is None:
+            down_dims = [256, 512, 1024]
         self.input_dim = action_dim
         self.noise_samples = noise_samples
         self.prediction_horizon = prediction_horizon
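The down_dims change above swaps a mutable list default for None plus an in-body fallback, so the default list is no longer shared across calls. A minimal standalone sketch of the pattern; the make_config helper is hypothetical and not part of this repo:

    # Mutable defaults are created once at definition time and shared across calls;
    # using None and filling in the default inside the body avoids that.
    def make_config(down_dims=None):
        if down_dims is None:
            down_dims = [256, 512, 1024]
        return {"down_dims": down_dims}

    cfg_a = make_config()
    cfg_a["down_dims"].append(2048)  # local mutation...
    assert make_config()["down_dims"] == [256, 512, 1024]  # ...does not leak into later calls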
@@ -6,14 +6,9 @@ from typing import Tuple

 import numpy as np

-try:
-    from typing import Literal
-except ImportError:
-    pass
-
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as Func
 import torch.utils.checkpoint
 from timm.models.vision_transformer import Mlp, use_fused_attn
 from torch.jit import Final
@@ -51,13 +46,13 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        b, n, c = x.shape
+        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(
+            x = Func.scaled_dot_product_attention(
                 q,
                 k,
                 v,
@@ -79,7 +74,7 @@ class Attention(nn.Module):
                 attn_scores += attn_mask

             # Apply softmax to get attention weights (softmax is applied along the last dimension)
-            attn_weights = F.softmax(attn_scores, dim=-1)
+            attn_weights = Func.softmax(attn_scores, dim=-1)

             # Dropout on attention weights (if dropout is used)
             attn_weights = self.attn_drop(attn_weights)
@@ -87,7 +82,7 @@ class Attention(nn.Module):
             # Apply attention weights to value tensor (V)
             x = torch.matmul(attn_weights, v)

-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(b, n, c)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
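The F-to-Func rename above only changes the import alias; the fused branch still calls torch.nn.functional.scaled_dot_product_attention and the manual branch reproduces the same scaled softmax attention. A small self-contained sketch of that equivalence (toy shapes, no masking or dropout, not code from the repo):

    import torch
    import torch.nn.functional as Func  # aliased as in the diff

    b, n, heads, head_dim = 2, 4, 3, 8
    q, k, v = (torch.randn(b, heads, n, head_dim) for _ in range(3))

    # Fused kernel, as used in the `fused_attn` branch.
    fused = Func.scaled_dot_product_attention(q, k, v)

    # Manual path: scaled scores -> softmax -> weighted sum over values.
    scores = (q @ k.transpose(-2, -1)) / head_dim**0.5
    manual = Func.softmax(scores, dim=-1) @ v

    assert torch.allclose(fused, manual, atol=1e-5)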
@@ -162,7 +157,8 @@ class ScaleDPBlock(nn.Module):
         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        approx_gelu = lambda: nn.GELU(approximate="tanh")
+        def approx_gelu():
+            return nn.GELU(approximate="tanh")
         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

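The lambda-to-def rewrite above is behavior-preserving: timm's Mlp takes act_layer as a factory callable, and both forms return a fresh nn.GELU(approximate="tanh") when invoked. A minimal sketch, assuming only torch is available:

    import torch
    import torch.nn as nn

    def approx_gelu():
        # Factory, not an instance: the MLP constructs the activation itself.
        return nn.GELU(approximate="tanh")

    act = approx_gelu()
    out = act(torch.randn(2, 16))  # tanh-approximated GELU, same result as the old lambda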
@@ -213,15 +209,15 @@ class ScaleDP(PreTrainedModel):
         # compute number of tokens for main trunk and conScaleDPion encoder
         if config.n_obs_steps is None:
             config.n_obs_steps = config.prediction_horizon
-        T = config.prediction_horizon
-        T_cond = 1
+        t = config.prediction_horizon
+        t_cond = 1
         if not config.time_as_cond:
-            T += 1
-            T_cond -= 1
+            t += 1
+            t_cond -= 1
         obs_as_cond = config.cond_dim > 0
         if obs_as_cond:
             assert config.time_as_cond
-            T_cond += config.n_obs_steps
+            t_cond += config.n_obs_steps

         # self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
         self.combine = nn.Sequential(
@@ -254,8 +250,8 @@ class ScaleDP(PreTrainedModel):
         self.final_layer = FinalLayer(config.n_emb, output_dim=config.output_dim)
         # self.initialize_weights()
         # constants
-        self.T = T
-        self.T_cond = T_cond
+        self.t = t
+        self.t_cond = t_cond
         self.prediction_horizon = config.prediction_horizon
         self.time_as_cond = config.time_as_cond
         self.action_dim = config.output_dim
@@ -328,7 +324,7 @@ class ScaleDP(PreTrainedModel):
         whitelist_weight_modules = (torch.nn.Linear, Attention)
         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
-            for pn, p in m.named_parameters():
+            for pn, _p in m.named_parameters():
                 fpn = "{}.{}".format(mn, pn) if mn else pn  # full param name

                 if pn.endswith("bias"):
@@ -345,7 +341,7 @@ class ScaleDP(PreTrainedModel):
                     no_decay.add(fpn)

         # validate that we considered every parameter
-        param_dict = {pn: p for pn, p in self.named_parameters()}
+        param_dict = dict(self.named_parameters())
         inter_params = decay & no_decay
         union_params = decay | no_decay
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
@@ -360,11 +356,11 @@ class ScaleDP(PreTrainedModel):
         # create the pytorch optimizer object
         optim_groups = [
             {
-                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "params": [param_dict[pn] for pn in sorted(decay)],
                 "weight_decay": weight_decay,
             },
             {
-                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "params": [param_dict[pn] for pn in sorted(no_decay)],
                 "weight_decay": 0.0,
            },
         ]
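Dropping the redundant list() above does not change the groups: sorted() accepts any iterable, including a set. For context, a compact sketch of how decay/no-decay parameter groups of this shape are typically handed to an optimizer; the AdamW choice and the toy model are assumptions, not shown in this diff:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
    param_dict = dict(model.named_parameters())

    # Toy split: 2-D Linear weights get weight decay, biases and norm params do not.
    decay = {pn for pn, p in param_dict.items() if pn.endswith("weight") and p.dim() >= 2}
    no_decay = set(param_dict) - decay

    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.1},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=1e-4)  # assumed optimizer class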
@@ -128,7 +128,7 @@ class ConditionalUnet1D(PreTrainedModel):
             in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
         diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
         down_dims: Channel size for each UNet level.
-            The length of this array determines numebr of levels.
+            The length of this array determines number of levels.
         kernel_size: Conv kernel size
         n_groups: Number of groups for GroupNorm
         """
@@ -301,7 +301,7 @@ class ConditionalUnet1D(PreTrainedModel):
         Tp = self.num_queries
         action_dim = 14

-        # initialize action from Guassian noise
+        # initialize action from Gaussian noise
         noisy_action = torch.randn((B, Tp, action_dim)).cuda()

         naction = noisy_action.to(dtype=hidden_states.dtype)
@@ -1937,12 +1937,12 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
         inputs_index = inputs_index.int()

         xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
-        indexs = torch.argmax((xor_array != 0).float(), dim=1)
+        indexes = torch.argmax((xor_array != 0).float(), dim=1)
         input_embeddings = []
         reasoning_embeddings = []
         identity = []
-        for i in range(indexs.shape[0]):
-            end = indexs[i] + 1
+        for i in range(indexes.shape[0]):
+            end = indexes[i] + 1
             temp = input_ids[i] == 151643  # pad token id for qwen2_vl
             start = sum(temp.int())
             input_embeddings.append(self.input_action_proj(hidden_states[i, start:end, :]))
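On the indexs-to-indexes rename above: the surrounding code locates, per row, the last position before a 0/1 mask first flips, by XOR-ing adjacent elements and taking the argmax over the non-zero positions. A minimal sketch with made-up mask values, not data from the model:

    import torch

    # 1 where a condition holds, 0 elsewhere; we want the first transition per row.
    inputs_index = torch.tensor([[1, 1, 1, 0, 0],
                                 [1, 0, 0, 0, 0]]).int()

    # Adjacent XOR is non-zero exactly at the last index before the first flip.
    xor_array = torch.bitwise_xor(inputs_index[:, :-1], inputs_index[:, 1:])
    indexes = torch.argmax((xor_array != 0).float(), dim=1)

    print(indexes)      # tensor([2, 0])
    print(indexes + 1)  # exclusive end offsets, as in `end = indexes[i] + 1`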