add __init__.py

This commit is contained in:
lesjie-wen 2025-03-18 16:47:33 +08:00
parent 61e40435ae
commit b4853011f8
8 changed files with 27 additions and 12 deletions

View File

@ -0,0 +1,13 @@
from .configuration_scaledp import ScaleDPPolicyConfig
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
from .modeling_scaledp import ScaleDP
from .modeling_unet_diffusion import ConditionalUnet1D
from transformers import AutoConfig, AutoModel
def register_policy_heads():
AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)

View File

@ -1,7 +1,7 @@
import os
from typing import Union
from transformers import AutoConfig, PretrainedConfig
from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
@ -106,4 +106,3 @@ class ScaleDPPolicyConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs)
AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)

View File

@ -1,7 +1,7 @@
import os
from typing import Union
from transformers import AutoConfig, PretrainedConfig
from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
@ -70,4 +70,3 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs)
AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)

View File

@ -11,7 +11,6 @@ import torch.nn.functional as func
import torch.utils.checkpoint
from timm.models.vision_transformer import Mlp, use_fused_attn
from torch.jit import Final
from transformers import AutoModel
from transformers.modeling_utils import PreTrainedModel
from .configuration_scaledp import ScaleDPPolicyConfig
@ -548,4 +547,3 @@ def scaledp_l(**kwargs):
return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)
AutoModel.register(ScaleDPPolicyConfig, ScaleDP)

View File

@ -11,7 +11,6 @@ import torch.nn as nn
# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from transformers import AutoModel
from transformers.modeling_utils import PreTrainedModel
from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
@ -376,4 +375,3 @@ class ConditionalUnet1D(PreTrainedModel):
return x
AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)

View File

@ -0,0 +1,11 @@
from .configuration_qwen2_vla import Qwen2VLAConfig
from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA
from transformers import AutoConfig, AutoModelForCausalLM
def register_qwen2_vla():
AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)

View File

@ -16,7 +16,6 @@
import os
from typing import Union
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
@ -254,4 +253,3 @@ class Qwen2VLAConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
AutoConfig.register("qwen2_vla", Qwen2VLAConfig)

View File

@ -28,7 +28,7 @@ import torch.nn as nn
import torch.nn.functional as func
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers import AutoConfig, AutoModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
@ -2049,4 +2049,3 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
return model_inputs
AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)