feat(autopolicy): draft v0.4
This commit is contained in:
parent
caadc887ad
commit
d4f9807ed0
|
@ -13,60 +13,65 @@
|
|||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type, Union
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
IMPORT_PATHS = ["lerobot.common.policies.{0}.configuration_{0}"]
|
||||
|
||||
POLICY_IMPORT_PATHS = ["lerobot.common.policies.{0}.modeling_{0}"]
|
||||
|
||||
|
||||
def policy_type_to_module_name(policy_type: str) -> str:
|
||||
"""Convert policy type to module name format."""
|
||||
# TODO(Steven): Deal with this
|
||||
"""
|
||||
Convert policy type to module name format.
|
||||
|
||||
Args:
|
||||
policy_type: The policy type identifier (e.g. 'lerobot/vqbet-pusht')
|
||||
|
||||
Returns:
|
||||
str: Normalized module name (e.g. 'vqbet')
|
||||
|
||||
Examples:
|
||||
>>> policy_type_to_module_name("lerobot/vqbet-pusht")
|
||||
'vqbet'
|
||||
"""
|
||||
# TODO(Steven): This is a temporary solution, we should have a more robust way to handle this
|
||||
return policy_type.replace("lerobot/", "").replace("-", "_").replace("_", "").replace("pusht", "")
|
||||
|
||||
|
||||
class _LazyPolicyConfigMapping(OrderedDict):
|
||||
"""
|
||||
A dictionary that lazily load its values when they are requested.
|
||||
"""
|
||||
|
||||
def __init__(self, mapping):
|
||||
def __init__(self, mapping: Dict[str, str]):
|
||||
self._mapping = mapping
|
||||
self._extra_content = {}
|
||||
self._modules = {}
|
||||
self._extra_content: Dict[str, Any] = {}
|
||||
self._modules: Dict[str, Any] = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
if key in self._extra_content:
|
||||
return self._extra_content[key]
|
||||
if key not in self._mapping:
|
||||
raise KeyError(key)
|
||||
raise KeyError(f"Policy type '{key}' not found in mapping")
|
||||
|
||||
value = self._mapping[key]
|
||||
module_name = policy_type_to_module_name(key)
|
||||
|
||||
# Try standard import path first
|
||||
try:
|
||||
if key not in self._modules:
|
||||
print("Importing CONFIG: ",module_name)
|
||||
self._modules[key] = importlib.import_module(
|
||||
f"lerobot.common.policies.{module_name}.configuration_{module_name}"
|
||||
)
|
||||
return getattr(self._modules[key], value)
|
||||
except (ImportError, AttributeError):
|
||||
# Try fallback paths
|
||||
for import_path in [
|
||||
f"lerobot.policies.{module_name}",
|
||||
f"lerobot.common.policies.{module_name}",
|
||||
]:
|
||||
try:
|
||||
print("Importing CONFIG: ",module_name)
|
||||
self._modules[key] = importlib.import_module(import_path)
|
||||
if hasattr(self._modules[key], value):
|
||||
return getattr(self._modules[key], value)
|
||||
except ImportError:
|
||||
continue
|
||||
for import_path in IMPORT_PATHS:
|
||||
try:
|
||||
if key not in self._modules:
|
||||
self._modules[key] = importlib.import_module(import_path.format(module_name))
|
||||
logger.debug(f"Config module: {module_name} imported")
|
||||
if hasattr(self._modules[key], value):
|
||||
return getattr(self._modules[key], value)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
raise ImportError(f"Could not find configuration class {value} for policy type {key}")
|
||||
|
||||
|
@ -109,12 +114,12 @@ class _LazyPolicyMapping(OrderedDict):
|
|||
A dictionary that lazily loads its values when they are requested.
|
||||
"""
|
||||
|
||||
def __init__(self, mapping):
|
||||
def __init__(self, mapping: Dict[str, str]):
|
||||
self._mapping = mapping
|
||||
self._extra_content = {}
|
||||
self._modules = {}
|
||||
self._config_mapping = {} # Maps config classes to policy classes
|
||||
self._initialized_types = set() # Track which types have been initialized
|
||||
self._extra_content: Dict[str, Type[PreTrainedPolicy]] = {}
|
||||
self._modules: Dict[str, Any] = {}
|
||||
self._config_mapping: Dict[Type[PreTrainedConfig], Type[PreTrainedPolicy]] = {}
|
||||
self._initialized_types: set[str] = set()
|
||||
|
||||
def _lazy_init_for_type(self, policy_type: str) -> None:
|
||||
"""Lazily initialize mappings for a policy type if not already done."""
|
||||
|
@ -124,41 +129,34 @@ class _LazyPolicyMapping(OrderedDict):
|
|||
self._config_mapping[config_class] = self[policy_type]
|
||||
self._initialized_types.add(policy_type)
|
||||
except (ImportError, AttributeError, KeyError) as e:
|
||||
import logging
|
||||
logging.warning(
|
||||
f"Could not automatically map config for policy type {policy_type}: {str(e)}"
|
||||
)
|
||||
logger.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}")
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Type[PreTrainedPolicy]:
|
||||
"""Get a policy class by key with lazy loading."""
|
||||
if key in self._extra_content:
|
||||
return self._extra_content[key]
|
||||
if key not in self._mapping:
|
||||
raise KeyError(key)
|
||||
raise KeyError(f"Policy type '{key}' not found in mapping")
|
||||
|
||||
value = self._mapping[key]
|
||||
module_name = policy_type_to_module_name(key)
|
||||
|
||||
try:
|
||||
if key not in self._modules:
|
||||
print("Importing POLICY: ", module_name)
|
||||
self._modules[key] = importlib.import_module(
|
||||
f"lerobot.common.policies.{module_name}.modeling_{module_name}"
|
||||
)
|
||||
return getattr(self._modules[key], value)
|
||||
except (ImportError, AttributeError):
|
||||
for import_path in [
|
||||
f"lerobot.policies.{module_name}",
|
||||
f"lerobot.common.policies.{module_name}",
|
||||
]:
|
||||
try:
|
||||
print("Importing POLICY: ",module_name)
|
||||
self._modules[key] = importlib.import_module(import_path)
|
||||
if hasattr(self._modules[key], value):
|
||||
return getattr(self._modules[key], value)
|
||||
except ImportError:
|
||||
continue
|
||||
for import_path in POLICY_IMPORT_PATHS:
|
||||
try:
|
||||
if key not in self._modules:
|
||||
self._modules[key] = importlib.import_module(import_path.format(module_name))
|
||||
logger.debug(
|
||||
f"Policy module: {module_name} imported from {import_path.format(module_name)}"
|
||||
)
|
||||
if hasattr(self._modules[key], value):
|
||||
return getattr(self._modules[key], value)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
raise ImportError(f"Could not find policy class {value} for policy type {key}")
|
||||
raise ImportError(
|
||||
f"Could not find policy class {value} for policy type {key}. "
|
||||
f"Tried paths: {[p.format(module_name) for p in POLICY_IMPORT_PATHS]}"
|
||||
)
|
||||
|
||||
def register(
|
||||
self,
|
||||
|
@ -166,26 +164,33 @@ class _LazyPolicyMapping(OrderedDict):
|
|||
value: Type[PreTrainedPolicy],
|
||||
config_class: Type[PreTrainedConfig],
|
||||
exist_ok: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Register a new policy class with its configuration class."""
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(f"Key must be a string, got {type(key)}")
|
||||
if not issubclass(value, PreTrainedPolicy):
|
||||
raise TypeError(f"Value must be a PreTrainedPolicy subclass, got {type(value)}")
|
||||
if not issubclass(config_class, PreTrainedConfig):
|
||||
raise TypeError(f"Config class must be a PreTrainedConfig subclass, got {type(config_class)}")
|
||||
|
||||
if key in self._mapping and not exist_ok:
|
||||
raise ValueError(f"'{key}' is already used by a Policy, pick another name.")
|
||||
self._extra_content[key] = value
|
||||
self._config_mapping[config_class] = value
|
||||
|
||||
def get_policy_for_config(self, config_class: PreTrainedConfig) -> Type[PreTrainedPolicy]:
|
||||
def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]:
|
||||
"""Get the policy class associated with a config class."""
|
||||
# First check direct config class mapping
|
||||
if type(config_class) in self._config_mapping:
|
||||
return self._config_mapping[type(config_class)]
|
||||
if config_class in self._config_mapping:
|
||||
return self._config_mapping[config_class]
|
||||
|
||||
# Try to find by policy type
|
||||
try:
|
||||
policy_type = config_class.type
|
||||
policy_type = config_class.get_type_str()
|
||||
# Check extra content first
|
||||
if policy_type in self._extra_content:
|
||||
return self._extra_content[policy_type]
|
||||
|
||||
|
||||
# Then check standard mapping
|
||||
if policy_type in self._mapping:
|
||||
self._lazy_init_for_type(policy_type)
|
||||
|
@ -222,10 +227,7 @@ class AutoPolicyConfig:
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise OSError(
|
||||
"AutoPolicyConfig is designed to be instantiated "
|
||||
"using the `AutoPolicyConfig.from_pretrained(TODO)` method."
|
||||
)
|
||||
raise OSError("AutoPolicyConfig not meant to be instantiated directly")
|
||||
|
||||
@classmethod
|
||||
def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig:
|
||||
|
@ -246,13 +248,12 @@ class AutoPolicyConfig:
|
|||
policy_type (`str`): The policy type like "act" or "pi0".
|
||||
config ([`PreTrainedConfig`]): The config to register.
|
||||
"""
|
||||
# TODO(Steven): config.type doesn't work at this stage because it is not an instance, it the class definition
|
||||
# if issubclass(config, PreTrainedConfig) and config.type != policy_type:
|
||||
# raise ValueError(
|
||||
# "The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
|
||||
# f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they "
|
||||
# "match!"
|
||||
# )
|
||||
if issubclass(config, PreTrainedConfig) and config.get_type_str() != policy_type:
|
||||
raise ValueError(
|
||||
"The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
|
||||
f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they "
|
||||
"match!"
|
||||
)
|
||||
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
|
||||
|
||||
@classmethod
|
||||
|
@ -308,15 +309,12 @@ class AutoPolicy:
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise OSError(
|
||||
"AutoPolicy is designed to be instantiated using the "
|
||||
"`AutoPolicy.from_config()` or `AutoPolicy.from_pretrained()` methods."
|
||||
)
|
||||
raise OSError("AutoPolicy not meant to be instantiated directly")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy:
|
||||
"""Instantiate a policy from a configuration."""
|
||||
policy_class = POLICY_MAPPING.get_policy_for_config(config)
|
||||
policy_class = POLICY_MAPPING.get_policy_for_config(type(config))
|
||||
return policy_class(config, **kwargs)
|
||||
|
||||
@classmethod
|
||||
|
@ -356,16 +354,33 @@ class AutoPolicy:
|
|||
policy_class: The policy class to register
|
||||
exist_ok: Whether to allow overwriting existing registrations
|
||||
"""
|
||||
POLICY_MAPPING.register(config_class.type, policy_class, config_class, exist_ok=exist_ok)
|
||||
POLICY_MAPPING.register(config_class.get_type_str(), policy_class, config_class, exist_ok=exist_ok)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Simulates a build-in policy type being loaded
|
||||
# Built-in policies work without explicit registration
|
||||
"""Test the AutoPolicy and AutoPolicyConfig functionality."""
|
||||
|
||||
# config = AutoPolicyConfig.for_policy("vqbet")
|
||||
config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht")
|
||||
def test_error_cases():
|
||||
"""Test error handling"""
|
||||
try:
|
||||
AutoPolicyConfig()
|
||||
except OSError as e:
|
||||
assert "not meant to be instantiated directly" in str(e)
|
||||
try:
|
||||
AutoPolicy()
|
||||
except OSError as e:
|
||||
assert "not meant to be instantiated directly" in str(e)
|
||||
|
||||
# try:
|
||||
# AutoPolicy.from_config("invalid_config")
|
||||
# except ValueError as e:
|
||||
# assert "Unrecognized policy identifier" in str(e)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Test built-in policy loading
|
||||
# config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht")
|
||||
config = AutoPolicyConfig.for_policy("vqbet")
|
||||
policy = AutoPolicy.from_config(config)
|
||||
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
|
@ -374,8 +389,7 @@ def main():
|
|||
assert isinstance(config, VQBeTConfig)
|
||||
assert isinstance(policy, VQBeTPolicy)
|
||||
|
||||
# Simulates a new policy type being registered
|
||||
# Only new policies need registration
|
||||
# Test policy registration
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
|
@ -387,6 +401,9 @@ def main():
|
|||
assert isinstance(my_new_config, TDMPCConfig)
|
||||
assert isinstance(my_new_policy, TDMPCPolicy)
|
||||
|
||||
# Run error case tests
|
||||
test_error_cases()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -47,6 +47,15 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
# TODO(Steven): Find a better way to do deal with this
|
||||
@classmethod
|
||||
def get_type_str(cls) -> str:
|
||||
"""Get the policy type identifier for this configuration class."""
|
||||
class_name = cls.__name__.lower()
|
||||
if class_name.endswith("config"):
|
||||
return class_name[:-6] # Remove 'config' suffix
|
||||
return class_name
|
||||
|
||||
@abc.abstractproperty
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
|
Loading…
Reference in New Issue