diff --git a/lerobot/common/policies/dexvla/policy_heads/__init__.py b/lerobot/common/policies/dexvla/policy_heads/__init__.py
new file mode 100644
index 00000000..991e8e0b
--- /dev/null
+++ b/lerobot/common/policies/dexvla/policy_heads/__init__.py
@@ -0,0 +1,13 @@
+from .configuration_scaledp import ScaleDPPolicyConfig
+from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
+from .modeling_scaledp import ScaleDP
+from .modeling_unet_diffusion import ConditionalUnet1D
+from transformers import AutoConfig, AutoModel
+
+
+def register_policy_heads():
+    AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
+    AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
+    AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
+    AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
+
diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py
index 385c8dc1..8ccc6196 100644
--- a/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py
+++ b/lerobot/common/policies/dexvla/policy_heads/configuration_scaledp.py
@@ -1,7 +1,7 @@
 import os
 from typing import Union
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -106,4 +106,3 @@ class ScaleDPPolicyConfig(PretrainedConfig):
         return cls.from_dict(config_dict, **kwargs)
 
 
-AutoConfig.register("scale_dp_policy", ScaleDPPolicyConfig)
diff --git a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py
index 6ca6fcbe..3c9af4d9 100644
--- a/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py
+++ b/lerobot/common/policies/dexvla/policy_heads/configuration_unet_diffusion.py
@@ -1,7 +1,7 @@
 import os
 from typing import Union
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -70,4 +70,3 @@ class UnetDiffusionPolicyConfig(PretrainedConfig):
         return cls.from_dict(config_dict, **kwargs)
 
 
-AutoConfig.register("unet_diffusion_policy", UnetDiffusionPolicyConfig)
diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py
index b09f5d24..5b8217fb 100644
--- a/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py
+++ b/lerobot/common/policies/dexvla/policy_heads/modeling_scaledp.py
@@ -11,7 +11,6 @@ import torch.nn.functional as func
 import torch.utils.checkpoint
 from timm.models.vision_transformer import Mlp, use_fused_attn
 from torch.jit import Final
-from transformers import AutoModel
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_scaledp import ScaleDPPolicyConfig
@@ -548,4 +547,3 @@ def scaledp_l(**kwargs):
     return ScaleDP(depth=24, n_emb=1024, num_heads=16, **kwargs)
 
 
-AutoModel.register(ScaleDPPolicyConfig, ScaleDP)
diff --git a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py
index 9a6a5f98..dc227ccb 100644
--- a/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py
+++ b/lerobot/common/policies/dexvla/policy_heads/modeling_unet_diffusion.py
@@ -11,7 +11,6 @@ import torch.nn as nn
 
 # requires diffusers==0.11.1
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from transformers import AutoModel
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_unet_diffusion import UnetDiffusionPolicyConfig
@@ -376,4 +375,3 @@ class ConditionalUnet1D(PreTrainedModel):
         return x
 
 
-AutoModel.register(UnetDiffusionPolicyConfig, ConditionalUnet1D)
diff --git a/lerobot/common/policies/dexvla/qwe2_vla/__init__.py b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py
new file mode 100644
index 00000000..23c7b636
--- /dev/null
+++ b/lerobot/common/policies/dexvla/qwe2_vla/__init__.py
@@ -0,0 +1,11 @@
+from .configuration_qwen2_vla import Qwen2VLAConfig
+from .modeling_qwen2_vla import Qwen2VLForConditionalGenerationForVLA
+
+from transformers import AutoConfig, AutoModelForCausalLM
+
+
+def register_qwen2_vla():
+    AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
+    AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
+
+
diff --git a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py
index 80717bc2..e3ea55b8 100644
--- a/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py
+++ b/lerobot/common/policies/dexvla/qwe2_vla/configuration_qwen2_vla.py
@@ -16,7 +16,6 @@
 import os
 from typing import Union
 
-from transformers import AutoConfig
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
@@ -254,4 +253,3 @@ class Qwen2VLAConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-AutoConfig.register("qwen2_vla", Qwen2VLAConfig)
diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
index 0fd81253..fa06a7b3 100644
--- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
+++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
@@ -28,7 +28,7 @@ import torch.nn as nn
 import torch.nn.functional as func
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModel
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
 from transformers.generation import GenerationMixin
@@ -2049,4 +2049,3 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
         return model_inputs
 
 
-AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
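
Taken together, the hunks above remove the import-time AutoConfig/AutoModel registration side effects and expose them as the explicit helpers register_policy_heads() and register_qwen2_vla(). A minimal sketch of how a caller might wire these up is shown below; the call site and the AutoConfig.for_model / AutoModel.from_config lookups are illustrative assumptions, not part of this diff.

# Sketch only: register the custom model types explicitly before any
# transformers Auto* lookup needs "scale_dp_policy", "unet_diffusion_policy"
# or "qwen2_vla".
from transformers import AutoConfig, AutoModel

from lerobot.common.policies.dexvla.policy_heads import register_policy_heads
from lerobot.common.policies.dexvla.qwe2_vla import register_qwen2_vla

register_policy_heads()
register_qwen2_vla()

# After registration the Auto* factories resolve these model types to the
# classes registered in the __init__ helpers (assuming the config defaults
# are sufficient for a standalone instantiation):
config = AutoConfig.for_model("scale_dp_policy")  # -> ScaleDPPolicyConfig
head = AutoModel.from_config(config)              # -> ScaleDP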