remove unused code

This commit is contained in:
lesjie-wen 2025-02-20 18:29:01 +08:00
parent 5701f02ea8
commit 15fdeb382c
4 changed files with 11 additions and 69 deletions

View File

@ -208,7 +208,7 @@ class DexVLAConfig(PretrainedConfig):
policy_head_size='DiT_L',
action_dim=10,
state_dim=7,
non_lora_lr=1e-4,
chunk_size=50,
**kwargs,
):
if isinstance(vision_config, dict):
@ -228,18 +228,11 @@ class DexVLAConfig(PretrainedConfig):
# for loading policy head
self.policy_head_type = policy_head_type
# if policy_head_type == 'dit_diffusion_policy':
# # self.policy_head_size = kwargs.get("policy_head_size", "none")
# self.policy_head_size = policy_head_size
# # self.policy_head_config = register_configuration_class(self.policy_head_type, model_size=policy_head_size)
# self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type,
# model_size=self.policy_head_size,
# global_cond_dim=hidden_size, action_dim=action_dim,
# state_dim=state_dim)
# elif policy_head_type == 'unet_diffusion_policy':
# self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type,
# global_cond_dim=hidden_size, action_dim=action_dim,
# state_dim=state_dim)
self.policy_head_config = AutoConfig.for_model(model_type=policy_head_type,
model_size=policy_head_size,
cond_dim=hidden_size, action_dim=action_dim,
prediction_horizon=chunk_size,
state_dim=state_dim)
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
@ -266,5 +259,3 @@ class DexVLAConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
from transformers import AutoConfig
AutoConfig.register("dex_vla", DexVLAConfig)

View File

@ -1455,12 +1455,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
self.vocab_size = config.vocab_size
self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.with_llm_head = config.with_llm_head
# self.with_external_vit = config.with_external_vit
self.with_text_fcs = config.with_text_fcs
self.only_using_input_embeddings = config.only_using_input_embeddings
self.using_film = config.using_film
self.using_xattn = config.using_xattn
self.llm_loss_weight = config.llm_loss_weight
@ -1472,8 +1466,8 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
# Initialize weights and apply final processing
self.post_init()
if config.policy_head_config.model_type == "dit_diffusion_policy":
self.policy_head.init_weights()
self.policy_head.init_weights()
self.input_action_proj = ActionProjector(config.hidden_size, config.hidden_size)
if self.using_film:
@ -1688,7 +1682,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
states: Optional[torch.FloatTensor] = None,
is_pad: bool = False,
is_eval: bool = False,
tinyvla: bool = False,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
Args:
@ -1802,8 +1795,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
)
hidden_states = outputs[0]
if tinyvla and is_eval: # dex-vla supports tinyvla-style VLA
return hidden_states
logits = self.lm_head(hidden_states)
logits = logits.float()
@ -1839,11 +1830,9 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
rope_deltas=rope_deltas,
)
if self.using_film:
action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids,
action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids,
hidden_states=hidden_states)
else: # tinyvla
action_hidden_states = hidden_states
ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad)
@ -2041,30 +2030,3 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
action = self.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad)
return action, outputs_text
def evaluate_tinyvla(self,
input_ids: torch.LongTensor = None,
actions=None,
states=None,
is_pad=None,
is_eval=True,
pixel_values=None,
attention_mask=None,
image_grid_thw=None,
):
input_ids = input_ids.to('cuda')
with torch.inference_mode():
all_hidden_states = self.forward(input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
image_grid_thw=image_grid_thw,
is_eval=is_eval,
tinyvla=True)
all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1)
action = self.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad)
return action, "tinyvla no output"
from transformers import AutoModelForCausalLM
AutoModelForCausalLM.register(DexVLAConfig, DexVLAPolicy)

View File

@ -34,10 +34,8 @@ class ScaleDPPolicyConfig(PretrainedConfig):
learn_sigma: bool = False,
model_size: str = "none",
num_inference_timesteps: int = 10,
num_queries: int = 16,
noise_samples: int = 1,
num_train_timesteps: int = 100,
is_tinyvla: bool = False,
**kwargs
):
if model_size != "none":
@ -54,7 +52,6 @@ class ScaleDPPolicyConfig(PretrainedConfig):
self.output_dim = action_dim
self.prediction_horizon = prediction_horizon
self.is_tinyvla = is_tinyvla
self.cond_dim = cond_dim
self.state_dim = state_dim

View File

@ -219,10 +219,6 @@ class ScaleDP(PreTrainedModel):
assert config.time_as_cond
T_cond += config.n_obs_steps
self.is_tinyvla = config.is_tinyvla
if config.is_tinyvla:
self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
self.norm_after_pool = nn.LayerNorm(config.cond_dim)
# self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
self.combine = nn.Sequential(
nn.Linear(config.cond_dim+config.state_dim, 1024),
@ -456,11 +452,7 @@ class ScaleDP(PreTrainedModel):
t: (N,) tensor of diffusion timesteps
global_cond: (N, n_obs_steps, D) tensor of conScaleDPions: image embeddings
"""
if self.is_tinyvla:
global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
global_cond = self.norm_after_pool(global_cond)
else:
global_cond = global_cond.squeeze(1)
global_cond = global_cond.squeeze(1)
global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
global_cond = self.combine(global_cond)