remove unused code

parent 5701f02ea8
commit 15fdeb382c
```diff
@@ -208,7 +208,7 @@ class DexVLAConfig(PretrainedConfig):
         policy_head_size='DiT_L',
         action_dim=10,
         state_dim=7,
-        non_lora_lr=1e-4,
+        chunk_size=50,
         **kwargs,
     ):
         if isinstance(vision_config, dict):
```
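The signature change drops the training-only `non_lora_lr` and adds `chunk_size`, which the constructor forwards to the policy-head config (next hunk). A hypothetical instantiation, assuming the defaults shown above; this snippet only makes sense inside the repo:

```python
# Hypothetical usage; names and defaults taken from the hunk above.
config = DexVLAConfig(
    policy_head_type="dit_diffusion_policy",
    policy_head_size="DiT_L",
    action_dim=10,
    state_dim=7,
    chunk_size=50,   # forwarded to the policy head as prediction_horizon
)
```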
```diff
@@ -228,18 +228,11 @@ class DexVLAConfig(PretrainedConfig):
 
         # for loading policy head
         self.policy_head_type = policy_head_type
-        # if policy_head_type == 'dit_diffusion_policy':
-        #     # self.policy_head_size = kwargs.get("policy_head_size", "none")
-        #     self.policy_head_size = policy_head_size
-        #     # self.policy_head_config = register_configuration_class(self.policy_head_type, model_size=policy_head_size)
-        #     self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type,
-        #                                                    model_size=self.policy_head_size,
-        #                                                    global_cond_dim=hidden_size, action_dim=action_dim,
-        #                                                    state_dim=state_dim)
-        # elif policy_head_type == 'unet_diffusion_policy':
-        #     self.policy_head_config = AutoConfig.for_model(model_type=self.policy_head_type,
-        #                                                    global_cond_dim=hidden_size, action_dim=action_dim,
-        #                                                    state_dim=state_dim)
+        self.policy_head_config = AutoConfig.for_model(model_type=policy_head_type,
+                                                       model_size=policy_head_size,
+                                                       cond_dim=hidden_size, action_dim=action_dim,
+                                                       prediction_horizon=chunk_size,
+                                                       state_dim=state_dim)
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
```
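The consolidated call goes through `AutoConfig.for_model`, which looks up the config class registered for a `model_type` string and forwards the remaining keyword arguments to its constructor. A minimal, self-contained sketch of that dispatch with a made-up `ToyHeadConfig`; the real policy-head configs ("dit_diffusion_policy", "unet_diffusion_policy") are presumably registered the same way elsewhere in the repo:

```python
from transformers import AutoConfig, PretrainedConfig

# Illustrative config only; not a class from this repo.
class ToyHeadConfig(PretrainedConfig):
    model_type = "toy_policy_head"

    def __init__(self, model_size="none", cond_dim=0, action_dim=10,
                 prediction_horizon=50, state_dim=7, **kwargs):
        self.model_size = model_size
        self.cond_dim = cond_dim
        self.action_dim = action_dim
        self.prediction_horizon = prediction_horizon
        self.state_dim = state_dim
        super().__init__(**kwargs)

AutoConfig.register("toy_policy_head", ToyHeadConfig)

# for_model resolves the registered class from the model_type string and
# passes the remaining kwargs to its constructor.
cfg = AutoConfig.for_model("toy_policy_head", cond_dim=1536, action_dim=14)
assert isinstance(cfg, ToyHeadConfig) and cfg.cond_dim == 1536
```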
```diff
@@ -266,5 +259,3 @@ class DexVLAConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
-from transformers import AutoConfig
-AutoConfig.register("dex_vla", DexVLAConfig)
 
```
```diff
@@ -1455,12 +1455,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
         self.vocab_size = config.vocab_size
 
         self.padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
-        self.with_llm_head = config.with_llm_head
-        # self.with_external_vit = config.with_external_vit
-        self.with_text_fcs = config.with_text_fcs
-        self.only_using_input_embeddings = config.only_using_input_embeddings
         self.using_film = config.using_film
-        self.using_xattn = config.using_xattn
-
         self.llm_loss_weight = config.llm_loss_weight
 
```
```diff
@@ -1472,8 +1466,8 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
 
         # Initialize weights and apply final processing
         self.post_init()
-        if config.policy_head_config.model_type == "dit_diffusion_policy":
-            self.policy_head.init_weights()
+        self.policy_head.init_weights()
+
         self.input_action_proj = ActionProjector(config.hidden_size, config.hidden_size)
 
         if self.using_film:
```
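With the type check gone, `init_weights()` is called for every policy head, not just the DiT/ScaleDP one, so any head wired in here is implicitly expected to expose that method. A hypothetical sketch of the implied interface (names invented, not code from this repo):

```python
import torch.nn as nn

# Any head plugged into DexVLAPolicy is now assumed to provide init_weights().
class ToyUnetHead(nn.Module):
    def __init__(self, cond_dim: int, action_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(cond_dim, 256),
            nn.Mish(),
            nn.Linear(256, action_dim),
        )

    def init_weights(self) -> None:
        # A common choice for diffusion heads: zero the final projection so
        # the head starts out predicting (near) zero noise.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)
```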
```diff
@@ -1688,7 +1682,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
         states: Optional[torch.FloatTensor] = None,
         is_pad: bool = False,
         is_eval: bool = False,
-        tinyvla: bool = False,
     ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
         r"""
         Args:
```
```diff
@@ -1802,8 +1795,6 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
         )
 
         hidden_states = outputs[0]
-        if tinyvla and is_eval:  # dex-vla supports tinyvla-style VLA
-            return hidden_states
 
         logits = self.lm_head(hidden_states)
         logits = logits.float()
```
```diff
@@ -1839,11 +1830,9 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
             rope_deltas=rope_deltas,
         )
 
         if self.using_film:
             action_hidden_states = self.film_forward(labels=labels, input_ids=input_ids,
                                                      hidden_states=hidden_states)
-        else:  # tinyvla
-            action_hidden_states = hidden_states
 
 
         ret = self.policy_head(actions=actions, hidden_states=action_hidden_states, states=states, is_pad=is_pad)
```
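`film_forward` (not shown in this diff) conditions the hidden states on language features via FiLM, a per-channel scale and shift predicted from a conditioning vector. A generic, self-contained sketch of the mechanism, not this repo's exact implementation:

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: per-channel scale and shift."""

    def __init__(self, feat_dim: int, cond_dim: int):
        super().__init__()
        self.to_scale_shift = nn.Linear(cond_dim, 2 * feat_dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return x * (1 + scale) + shift  # identity if scale and shift are zero

hidden = torch.randn(2, 50, 1536)                 # token features from the VLM
language = torch.randn(2, 1, 1536)                # pooled instruction embedding
conditioned = FiLM(1536, 1536)(hidden, language)  # broadcasts over tokens
```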
```diff
@@ -2041,30 +2030,3 @@ class DexVLAPolicy(Qwen2VLPreTrainedModel, GenerationMixin):
         action = self.policy_head(actions, action_hidden_states, states.to(all_hidden_states.dtype), is_pad)
         return action, outputs_text
 
-    def evaluate_tinyvla(self,
-                         input_ids: torch.LongTensor = None,
-                         actions=None,
-                         states=None,
-                         is_pad=None,
-                         is_eval=True,
-                         pixel_values=None,
-                         attention_mask=None,
-                         image_grid_thw=None,
-                         ):
-        input_ids = input_ids.to('cuda')
-        with torch.inference_mode():
-            all_hidden_states = self.forward(input_ids,
-                                             pixel_values=pixel_values,
-                                             attention_mask=attention_mask,
-                                             image_grid_thw=image_grid_thw,
-                                             is_eval=is_eval,
-                                             tinyvla=True)
-
-        all_hidden_states = torch.mean(all_hidden_states, dim=1).unsqueeze(1)
-
-        action = self.policy_head(actions, all_hidden_states, states.to(all_hidden_states.dtype), is_pad)
-        return action, "tinyvla no output"
-
-
-from transformers import AutoModelForCausalLM
-AutoModelForCausalLM.register(DexVLAConfig, DexVLAPolicy)
```
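This hunk also drops the module-level Auto-class registration, matching the `AutoConfig.register` removal in the config file; presumably both registrations now happen at a single import point not shown in this diff. For reference, the pattern the removed lines implemented, demonstrated with self-contained stand-in classes:

```python
import tempfile
import torch.nn as nn
from transformers import (AutoConfig, AutoModelForCausalLM,
                          PretrainedConfig, PreTrainedModel)

class DemoVLAConfig(PretrainedConfig):
    model_type = "demo_vla"          # stands in for "dex_vla"

class DemoVLAPolicy(PreTrainedModel):
    config_class = DemoVLAConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(4, 4)  # stand-in for the real backbone

# The two-step pattern the removed module-level code performed:
AutoConfig.register("demo_vla", DemoVLAConfig)
AutoModelForCausalLM.register(DemoVLAConfig, DemoVLAPolicy)

# Once both registrations have run, checkpoints round-trip via the Auto API.
with tempfile.TemporaryDirectory() as d:
    DemoVLAPolicy(DemoVLAConfig()).save_pretrained(d)
    model = AutoModelForCausalLM.from_pretrained(d)
    assert isinstance(model, DemoVLAPolicy)
```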
```diff
@@ -34,10 +34,8 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         learn_sigma: bool = False,
         model_size: str = "none",
         num_inference_timesteps: int = 10,
-        num_queries: int = 16,
         noise_samples: int = 1,
         num_train_timesteps: int = 100,
-        is_tinyvla: bool = False,
         **kwargs
     ):
         if model_size != "none":
```
```diff
@@ -54,7 +52,6 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         self.output_dim = action_dim
         self.prediction_horizon = prediction_horizon
 
-        self.is_tinyvla = is_tinyvla
 
         self.cond_dim = cond_dim
         self.state_dim = state_dim
```
```diff
@@ -219,10 +219,6 @@ class ScaleDP(PreTrainedModel):
             assert config.time_as_cond
             T_cond += config.n_obs_steps
 
-        self.is_tinyvla = config.is_tinyvla
-        if config.is_tinyvla:
-            self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
-            self.norm_after_pool = nn.LayerNorm(config.cond_dim)
         # self.combine = nn.Linear(cond_dim+state_dim, cond_dim)
         self.combine = nn.Sequential(
             nn.Linear(config.cond_dim+config.state_dim, 1024),
```
```diff
@@ -456,11 +452,7 @@ class ScaleDP(PreTrainedModel):
         t: (N,) tensor of diffusion timesteps
         global_cond: (N, n_obs_steps, D) tensor of conditions: image embeddings
         """
-        if self.is_tinyvla:
-            global_cond = self.global_1d_pool(global_cond.permute(0, 2, 1)).squeeze(-1)
-            global_cond = self.norm_after_pool(global_cond)
-        else:
-            global_cond = global_cond.squeeze(1)
+        global_cond = global_cond.squeeze(1)
         global_cond = torch.cat([global_cond, states], dim=-1) if states is not None else global_cond
         global_cond = self.combine(global_cond)
 
```
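After this change `global_cond` is assumed to arrive as a single conditioning vector per sample, shape `(N, 1, D)`; the deleted TinyVLA branch instead mean-pooled a whole token sequence before conditioning. A self-contained sketch of what the removed path computed (shapes illustrative):

```python
import torch
import torch.nn as nn

B, T, D = 2, 64, 1536                 # batch, tokens, embed dim (illustrative)
tokens = torch.randn(B, T, D)

# Deleted TinyVLA path: mean over the token axis, then LayerNorm.
pool = nn.AdaptiveAvgPool1d(1)
norm = nn.LayerNorm(D)
pooled = norm(pool(tokens.permute(0, 2, 1)).squeeze(-1))    # (B, D)

# The pooling is just a mean, as torch.mean in the removed evaluate_tinyvla
# made explicit:
assert torch.allclose(pooled, norm(tokens.mean(dim=1)), atol=1e-6)

# Surviving path: conditioning arrives already pooled as (B, 1, D).
cond = torch.randn(B, 1, D).squeeze(1)                      # (B, D)
```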