add create config and policy for dexvla
This commit is contained in:
parent
20f346956a
commit
5701f02ea8
|
@ -3,3 +3,4 @@ from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfi
|
|||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .dexvla.configuration_dexvla import DexVLAConfig as DexVLAConfig
|
||||
|
|
|
@ -26,6 +26,7 @@ from lerobot.common.envs.utils import env_to_policy_features
|
|||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
|
@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
|||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "dexvla":
|
||||
from lerobot.common.policies.dexvla.modeling_dexvla import DexVLAPolicy
|
||||
|
||||
return DexVLAPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
@ -70,6 +75,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "dexvla":
|
||||
return DexVLAConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue