add __init__.py
parent 61e40435ae
commit b4853011f8
__init__.py (policy heads package, new file)
@@ -0,0 +1,13 @@
+from .configuration_scaledp import ScaleDPPolicyConfig
+from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
+from .modeling_scaledp import ScaleDP
+from .modeling_unet_diffusion import ConditionalUnet1D
+
+from transformers import AutoConfig, AutoModel
+
+
+def register_policy_heads():
+    AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
+    AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
+    AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
+    AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
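Note: with registration moved behind an explicit function instead of an import side effect, a caller opts in once and then resolves the heads through the usual Auto* lookup. A minimal sketch of the intended call site, assuming the new file lives in an importable policy_heads package and that the config class constructs with defaults (both assumptions, not shown in this commit):

    from transformers import AutoConfig, AutoModel

    from policy_heads import register_policy_heads  # assumed package name

    # Populate the Auto* mappings once, explicitly.
    register_policy_heads()

    # Resolve the custom classes by their registered model_type string.
    config = AutoConfig.for_model("scale_dp_policy")  # -> ScaleDPPolicyConfig
    model = AutoModel.from_config(config)             # -> ScaleDP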
configuration_scaledp.py
@@ -1,7 +1,7 @@
 import os
 from typing import Union
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -106,4 +106,3 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         return cls.from_dict(config_dict, **kwargs)
 
 
-AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
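Note: AutoConfig.register validates that a PretrainedConfig subclass's model_type attribute equals the string it is registered under, and raises a ValueError otherwise. A self-contained toy illustration (the class below is a stand-in, not the repo's ScaleDPPolicyConfig):

    from transformers import AutoConfig, PretrainedConfig


    class ToyScaleDPConfig(PretrainedConfig):
        # Must match the name passed to register(), or register() raises.
        model_type = "scale_dp_policy"


    AutoConfig.register("scale_dp_policy", ToyScaleDPConfig)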
configuration_unet_diffusion.py
@@ -1,7 +1,7 @@
 import os
 from typing import Union
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -70,4 +70,3 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         return cls.from_dict(config_dict, **kwargs)
 
 
-AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
modeling_scaledp.py
@@ -11,7 +11,6 @@ import torch.nn.functional as func
 import torch.utils.checkpoint
 from timm.models.vision_transformer import Mlp, use_fused_attn
 from torch.jit import Final
-from transformers import AutoModel
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_scaledp import ScaleDPPolicyConfig
@@ -548,4 +547,3 @@ def scaledp_l(**kwargs):
     return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)
 
 
-AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
modeling_unet_diffusion.py
@@ -11,7 +11,6 @@ import torch.nn as nn
 
 # requires diffusers==0.11.1
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from transformers import AutoModel
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
@@ -376,4 +375,3 @@ class ConditionalUnet1D(PreTrainedModel):
         return x
 
 
-AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
__init__.py (qwen2_vla package, new file)
@@ -0,0 +1,11 @@
+from .configuration_qwen2_vla import Qwen2VLAConfig
+from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA
+
+from transformers import AutoConfig, AutoModelForCausalLM
+
+
+def register_qwen2_vla():
+    AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
+    AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
+
+
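Note: the VLA backbone follows the same pattern as the policy heads, registered against AutoModelForCausalLM rather than AutoModel. A hedged sketch of the intended usage (package name assumed; the checkpoint path is a placeholder):

    from transformers import AutoModelForCausalLM

    from qwen2_vla import register_qwen2_vla  # assumed package name

    register_qwen2_vla()

    # After registration, a checkpoint whose config.json declares
    # "model_type": "qwen2_vla" resolves to the custom classes automatically.
    model = AutoModelForCausalLM.from_pretrained("path/to/qwen2_vla_checkpoint")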
configuration_qwen2_vla.py
@@ -16,7 +16,6 @@
 import os
 from typing import Union
 
-from transformers import AutoConfig
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
@@ -254,4 +253,3 @@ class Qwen2VLAConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
modeling_qwen2_vla.py
@@ -28,7 +28,7 @@ import torch.nn as nn
 import torch.nn.functional as func
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModel
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
 from transformers.generation import GenerationMixin
@@ -2049,4 +2049,3 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
         return model_inputs
 
 
-AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
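Note: taken together, the commit replaces six module-level register() side effects with two explicit entry points, so importing the configuration or modeling modules no longer mutates the global Auto* mappings. A sketch of a combined bootstrap a downstream package might expose (function and package names assumed, not part of this commit):

    from policy_heads import register_policy_heads  # assumed package name
    from qwen2_vla import register_qwen2_vla        # assumed package name


    def register_all():
        # Single call site that makes every custom config/model visible to Auto*.
        register_policy_heads()
        register_qwen2_vla()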