feat(autopolicy): draft v0.4

This commit is contained in:
Steven Palma 2025-03-06 14:23:00 +01:00
parent caadc887ad
commit d4f9807ed0
No known key found for this signature in database
GPG Key ID: F4D348E38D8F8ADD
2 changed files with 120 additions and 94 deletions

View File

@ -13,60 +13,65 @@
# limitations under the License. # limitations under the License.
import importlib import importlib
import logging
import os import os
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import Optional, Type, Union from typing import Any, Dict, Optional, Type, Union
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
logger = logging.getLogger(__name__)
# Constants
IMPORT_PATHS = ["lerobot.common.policies.{0}.configuration_{0}"]
POLICY_IMPORT_PATHS = ["lerobot.common.policies.{0}.modeling_{0}"]
def policy_type_to_module_name(policy_type: str) -> str: def policy_type_to_module_name(policy_type: str) -> str:
"""Convert policy type to module name format.""" """
# TODO(Steven): Deal with this Convert policy type to module name format.
Args:
policy_type: The policy type identifier (e.g. 'lerobot/vqbet-pusht')
Returns:
str: Normalized module name (e.g. 'vqbet')
Examples:
>>> policy_type_to_module_name("lerobot/vqbet-pusht")
'vqbet'
"""
# TODO(Steven): This is a temporary solution, we should have a more robust way to handle this
return policy_type.replace("lerobot/", "").replace("-", "_").replace("_", "").replace("pusht", "") return policy_type.replace("lerobot/", "").replace("-", "_").replace("_", "").replace("pusht", "")
class _LazyPolicyConfigMapping(OrderedDict): class _LazyPolicyConfigMapping(OrderedDict):
""" def __init__(self, mapping: Dict[str, str]):
A dictionary that lazily load its values when they are requested.
"""
def __init__(self, mapping):
self._mapping = mapping self._mapping = mapping
self._extra_content = {} self._extra_content: Dict[str, Any] = {}
self._modules = {} self._modules: Dict[str, Any] = {}
def __getitem__(self, key): def __getitem__(self, key: str) -> Any:
if key in self._extra_content: if key in self._extra_content:
return self._extra_content[key] return self._extra_content[key]
if key not in self._mapping: if key not in self._mapping:
raise KeyError(key) raise KeyError(f"Policy type '{key}' not found in mapping")
value = self._mapping[key] value = self._mapping[key]
module_name = policy_type_to_module_name(key) module_name = policy_type_to_module_name(key)
# Try standard import path first for import_path in IMPORT_PATHS:
try: try:
if key not in self._modules: if key not in self._modules:
print("Importing CONFIG: ",module_name) self._modules[key] = importlib.import_module(import_path.format(module_name))
self._modules[key] = importlib.import_module( logger.debug(f"Config module: {module_name} imported")
f"lerobot.common.policies.{module_name}.configuration_{module_name}" if hasattr(self._modules[key], value):
) return getattr(self._modules[key], value)
return getattr(self._modules[key], value) except ImportError:
except (ImportError, AttributeError): continue
# Try fallback paths
for import_path in [
f"lerobot.policies.{module_name}",
f"lerobot.common.policies.{module_name}",
]:
try:
print("Importing CONFIG: ",module_name)
self._modules[key] = importlib.import_module(import_path)
if hasattr(self._modules[key], value):
return getattr(self._modules[key], value)
except ImportError:
continue
raise ImportError(f"Could not find configuration class {value} for policy type {key}") raise ImportError(f"Could not find configuration class {value} for policy type {key}")
@ -109,12 +114,12 @@ class _LazyPolicyMapping(OrderedDict):
A dictionary that lazily loads its values when they are requested. A dictionary that lazily loads its values when they are requested.
""" """
def __init__(self, mapping): def __init__(self, mapping: Dict[str, str]):
self._mapping = mapping self._mapping = mapping
self._extra_content = {} self._extra_content: Dict[str, Type[PreTrainedPolicy]] = {}
self._modules = {} self._modules: Dict[str, Any] = {}
self._config_mapping = {} # Maps config classes to policy classes self._config_mapping: Dict[Type[PreTrainedConfig], Type[PreTrainedPolicy]] = {}
self._initialized_types = set() # Track which types have been initialized self._initialized_types: set[str] = set()
def _lazy_init_for_type(self, policy_type: str) -> None: def _lazy_init_for_type(self, policy_type: str) -> None:
"""Lazily initialize mappings for a policy type if not already done.""" """Lazily initialize mappings for a policy type if not already done."""
@ -124,41 +129,34 @@ class _LazyPolicyMapping(OrderedDict):
self._config_mapping[config_class] = self[policy_type] self._config_mapping[config_class] = self[policy_type]
self._initialized_types.add(policy_type) self._initialized_types.add(policy_type)
except (ImportError, AttributeError, KeyError) as e: except (ImportError, AttributeError, KeyError) as e:
import logging logger.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}")
logging.warning(
f"Could not automatically map config for policy type {policy_type}: {str(e)}"
)
def __getitem__(self, key): def __getitem__(self, key: str) -> Type[PreTrainedPolicy]:
"""Get a policy class by key with lazy loading."""
if key in self._extra_content: if key in self._extra_content:
return self._extra_content[key] return self._extra_content[key]
if key not in self._mapping: if key not in self._mapping:
raise KeyError(key) raise KeyError(f"Policy type '{key}' not found in mapping")
value = self._mapping[key] value = self._mapping[key]
module_name = policy_type_to_module_name(key) module_name = policy_type_to_module_name(key)
try: for import_path in POLICY_IMPORT_PATHS:
if key not in self._modules: try:
print("Importing POLICY: ", module_name) if key not in self._modules:
self._modules[key] = importlib.import_module( self._modules[key] = importlib.import_module(import_path.format(module_name))
f"lerobot.common.policies.{module_name}.modeling_{module_name}" logger.debug(
) f"Policy module: {module_name} imported from {import_path.format(module_name)}"
return getattr(self._modules[key], value) )
except (ImportError, AttributeError): if hasattr(self._modules[key], value):
for import_path in [ return getattr(self._modules[key], value)
f"lerobot.policies.{module_name}", except ImportError:
f"lerobot.common.policies.{module_name}", continue
]:
try:
print("Importing POLICY: ",module_name)
self._modules[key] = importlib.import_module(import_path)
if hasattr(self._modules[key], value):
return getattr(self._modules[key], value)
except ImportError:
continue
raise ImportError(f"Could not find policy class {value} for policy type {key}") raise ImportError(
f"Could not find policy class {value} for policy type {key}. "
f"Tried paths: {[p.format(module_name) for p in POLICY_IMPORT_PATHS]}"
)
def register( def register(
self, self,
@ -166,22 +164,29 @@ class _LazyPolicyMapping(OrderedDict):
value: Type[PreTrainedPolicy], value: Type[PreTrainedPolicy],
config_class: Type[PreTrainedConfig], config_class: Type[PreTrainedConfig],
exist_ok: bool = False, exist_ok: bool = False,
): ) -> None:
"""Register a new policy class with its configuration class.""" """Register a new policy class with its configuration class."""
if not isinstance(key, str):
raise TypeError(f"Key must be a string, got {type(key)}")
if not issubclass(value, PreTrainedPolicy):
raise TypeError(f"Value must be a PreTrainedPolicy subclass, got {type(value)}")
if not issubclass(config_class, PreTrainedConfig):
raise TypeError(f"Config class must be a PreTrainedConfig subclass, got {type(config_class)}")
if key in self._mapping and not exist_ok: if key in self._mapping and not exist_ok:
raise ValueError(f"'{key}' is already used by a Policy, pick another name.") raise ValueError(f"'{key}' is already used by a Policy, pick another name.")
self._extra_content[key] = value self._extra_content[key] = value
self._config_mapping[config_class] = value self._config_mapping[config_class] = value
def get_policy_for_config(self, config_class: PreTrainedConfig) -> Type[PreTrainedPolicy]: def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]:
"""Get the policy class associated with a config class.""" """Get the policy class associated with a config class."""
# First check direct config class mapping # First check direct config class mapping
if type(config_class) in self._config_mapping: if config_class in self._config_mapping:
return self._config_mapping[type(config_class)] return self._config_mapping[config_class]
# Try to find by policy type # Try to find by policy type
try: try:
policy_type = config_class.type policy_type = config_class.get_type_str()
# Check extra content first # Check extra content first
if policy_type in self._extra_content: if policy_type in self._extra_content:
return self._extra_content[policy_type] return self._extra_content[policy_type]
@ -222,10 +227,7 @@ class AutoPolicyConfig:
""" """
def __init__(self): def __init__(self):
raise OSError( raise OSError("AutoPolicyConfig not meant to be instantiated directly")
"AutoPolicyConfig is designed to be instantiated "
"using the `AutoPolicyConfig.from_pretrained(TODO)` method."
)
@classmethod @classmethod
def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig: def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig:
@ -246,13 +248,12 @@ class AutoPolicyConfig:
policy_type (`str`): The policy type like "act" or "pi0". policy_type (`str`): The policy type like "act" or "pi0".
config ([`PreTrainedConfig`]): The config to register. config ([`PreTrainedConfig`]): The config to register.
""" """
# TODO(Steven): config.type doesn't work at this stage because it is not an instance, it the class definition if issubclass(config, PreTrainedConfig) and config.get_type_str() != policy_type:
# if issubclass(config, PreTrainedConfig) and config.type != policy_type: raise ValueError(
# raise ValueError( "The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
# "The config you are passing has a `policy_type` attribute that is not consistent with the policy type " f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they "
# f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they " "match!"
# "match!" )
# )
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok) POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
@classmethod @classmethod
@ -308,15 +309,12 @@ class AutoPolicy:
""" """
def __init__(self): def __init__(self):
raise OSError( raise OSError("AutoPolicy not meant to be instantiated directly")
"AutoPolicy is designed to be instantiated using the "
"`AutoPolicy.from_config()` or `AutoPolicy.from_pretrained()` methods."
)
@classmethod @classmethod
def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy: def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy:
"""Instantiate a policy from a configuration.""" """Instantiate a policy from a configuration."""
policy_class = POLICY_MAPPING.get_policy_for_config(config) policy_class = POLICY_MAPPING.get_policy_for_config(type(config))
return policy_class(config, **kwargs) return policy_class(config, **kwargs)
@classmethod @classmethod
@ -356,16 +354,33 @@ class AutoPolicy:
policy_class: The policy class to register policy_class: The policy class to register
exist_ok: Whether to allow overwriting existing registrations exist_ok: Whether to allow overwriting existing registrations
""" """
POLICY_MAPPING.register(config_class.type, policy_class, config_class, exist_ok=exist_ok) POLICY_MAPPING.register(config_class.get_type_str(), policy_class, config_class, exist_ok=exist_ok)
def main(): def main():
"""Test the AutoPolicy and AutoPolicyConfig functionality."""
# Simulates a build-in policy type being loaded def test_error_cases():
# Built-in policies work without explicit registration """Test error handling"""
try:
AutoPolicyConfig()
except OSError as e:
assert "not meant to be instantiated directly" in str(e)
try:
AutoPolicy()
except OSError as e:
assert "not meant to be instantiated directly" in str(e)
# config = AutoPolicyConfig.for_policy("vqbet") # try:
config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht") # AutoPolicy.from_config("invalid_config")
# except ValueError as e:
# assert "Unrecognized policy identifier" in str(e)
logging.basicConfig(level=logging.DEBUG)
# Test built-in policy loading
# config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht")
config = AutoPolicyConfig.for_policy("vqbet")
policy = AutoPolicy.from_config(config) policy = AutoPolicy.from_config(config)
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
@ -374,8 +389,7 @@ def main():
assert isinstance(config, VQBeTConfig) assert isinstance(config, VQBeTConfig)
assert isinstance(policy, VQBeTPolicy) assert isinstance(policy, VQBeTPolicy)
# Simulates a new policy type being registered # Test policy registration
# Only new policies need registration
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@ -387,6 +401,9 @@ def main():
assert isinstance(my_new_config, TDMPCConfig) assert isinstance(my_new_config, TDMPCConfig)
assert isinstance(my_new_policy, TDMPCPolicy) assert isinstance(my_new_policy, TDMPCPolicy)
# Run error case tests
test_error_cases()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -47,6 +47,15 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
def type(self) -> str: def type(self) -> str:
return self.get_choice_name(self.__class__) return self.get_choice_name(self.__class__)
# TODO(Steven): Find a better way to do deal with this
@classmethod
def get_type_str(cls) -> str:
"""Get the policy type identifier for this configuration class."""
class_name = cls.__name__.lower()
if class_name.endswith("config"):
return class_name[:-6] # Remove 'config' suffix
return class_name
@abc.abstractproperty @abc.abstractproperty
def observation_delta_indices(self) -> list | None: def observation_delta_indices(self) -> list | None:
raise NotImplementedError raise NotImplementedError