diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py index 8a54c0d2..5c1f6743 100644 --- a/lerobot/common/policies/dexvla/configuration_dexvla.py +++ b/lerobot/common/policies/dexvla/configuration_dexvla.py @@ -28,14 +28,18 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode logger = logging.get_logger(__name__) +from .policy_heads import register_policy_heads +from .qwe2_vla import register_qwen2_vla +register_policy_heads() +register_qwen2_vla() @PreTrainedConfig.register_subclass("dexvla") @dataclass class DexVLAConfig(PreTrainedConfig): # For loading policy head policy_head_type: str = "scale_dp_policy" - policy_head_size: str = "ScaleDP_L" + policy_head_size: str = "scaledp_l" action_dim: int = 14 state_dim: int = 14 chunk_size: int = 50