add register policy_head and qwen2_vla
This commit is contained in:
parent
b4853011f8
commit
fcb2047310
|
@ -28,14 +28,18 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
from .policy_heads import register_policy_heads
|
||||
from .qwe2_vla import register_qwen2_vla
|
||||
|
||||
register_policy_heads()
|
||||
register_qwen2_vla()
|
||||
|
||||
@PreTrainedConfig.register_subclass("dexvla")
|
||||
@dataclass
|
||||
class DexVLAConfig(PreTrainedConfig):
|
||||
# For loading policy head
|
||||
policy_head_type: str = "scale_dp_policy"
|
||||
policy_head_size: str = "ScaleDP_L"
|
||||
policy_head_size: str = "scaledp_l"
|
||||
action_dim: int = 14
|
||||
state_dim: int = 14
|
||||
chunk_size: int = 50
|
||||
|
|
Loading…
Reference in New Issue