feat(autopolicy): draft v0.4

This commit is contained in:
Steven Palma 2025-03-06 14:23:00 +01:00
parent caadc887ad
commit d4f9807ed0
No known key found for this signature in database
GPG Key ID: F4D348E38D8F8ADD
2 changed files with 120 additions and 94 deletions

View File

@ -13,56 +13,61 @@
# limitations under the License.
import importlib
import logging
import os
from collections import OrderedDict
from pathlib import Path
from typing import Optional, Type, Union
from typing import Any, Dict, Optional, Type, Union
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.configs.policies import PreTrainedConfig
logger = logging.getLogger(__name__)
# Constants
IMPORT_PATHS = ["lerobot.common.policies.{0}.configuration_{0}"]
POLICY_IMPORT_PATHS = ["lerobot.common.policies.{0}.modeling_{0}"]
def policy_type_to_module_name(policy_type: str) -> str:
"""Convert policy type to module name format."""
# TODO(Steven): Deal with this
"""
Convert policy type to module name format.
Args:
policy_type: The policy type identifier (e.g. 'lerobot/vqbet-pusht')
Returns:
str: Normalized module name (e.g. 'vqbet')
Examples:
>>> policy_type_to_module_name("lerobot/vqbet-pusht")
'vqbet'
"""
# TODO(Steven): This is a temporary solution, we should have a more robust way to handle this
return policy_type.replace("lerobot/", "").replace("-", "_").replace("_", "").replace("pusht", "")
class _LazyPolicyConfigMapping(OrderedDict):
"""
A dictionary that lazily load its values when they are requested.
"""
def __init__(self, mapping):
def __init__(self, mapping: Dict[str, str]):
self._mapping = mapping
self._extra_content = {}
self._modules = {}
self._extra_content: Dict[str, Any] = {}
self._modules: Dict[str, Any] = {}
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
if key in self._extra_content:
return self._extra_content[key]
if key not in self._mapping:
raise KeyError(key)
raise KeyError(f"Policy type '{key}' not found in mapping")
value = self._mapping[key]
module_name = policy_type_to_module_name(key)
# Try standard import path first
for import_path in IMPORT_PATHS:
try:
if key not in self._modules:
print("Importing CONFIG: ",module_name)
self._modules[key] = importlib.import_module(
f"lerobot.common.policies.{module_name}.configuration_{module_name}"
)
return getattr(self._modules[key], value)
except (ImportError, AttributeError):
# Try fallback paths
for import_path in [
f"lerobot.policies.{module_name}",
f"lerobot.common.policies.{module_name}",
]:
try:
print("Importing CONFIG: ",module_name)
self._modules[key] = importlib.import_module(import_path)
self._modules[key] = importlib.import_module(import_path.format(module_name))
logger.debug(f"Config module: {module_name} imported")
if hasattr(self._modules[key], value):
return getattr(self._modules[key], value)
except ImportError:
@ -109,12 +114,12 @@ class _LazyPolicyMapping(OrderedDict):
A dictionary that lazily loads its values when they are requested.
"""
def __init__(self, mapping):
def __init__(self, mapping: Dict[str, str]):
self._mapping = mapping
self._extra_content = {}
self._modules = {}
self._config_mapping = {} # Maps config classes to policy classes
self._initialized_types = set() # Track which types have been initialized
self._extra_content: Dict[str, Type[PreTrainedPolicy]] = {}
self._modules: Dict[str, Any] = {}
self._config_mapping: Dict[Type[PreTrainedConfig], Type[PreTrainedPolicy]] = {}
self._initialized_types: set[str] = set()
def _lazy_init_for_type(self, policy_type: str) -> None:
"""Lazily initialize mappings for a policy type if not already done."""
@ -124,41 +129,34 @@ class _LazyPolicyMapping(OrderedDict):
self._config_mapping[config_class] = self[policy_type]
self._initialized_types.add(policy_type)
except (ImportError, AttributeError, KeyError) as e:
import logging
logging.warning(
f"Could not automatically map config for policy type {policy_type}: {str(e)}"
)
logger.warning(f"Could not automatically map config for policy type {policy_type}: {str(e)}")
def __getitem__(self, key):
def __getitem__(self, key: str) -> Type[PreTrainedPolicy]:
"""Get a policy class by key with lazy loading."""
if key in self._extra_content:
return self._extra_content[key]
if key not in self._mapping:
raise KeyError(key)
raise KeyError(f"Policy type '{key}' not found in mapping")
value = self._mapping[key]
module_name = policy_type_to_module_name(key)
for import_path in POLICY_IMPORT_PATHS:
try:
if key not in self._modules:
print("Importing POLICY: ", module_name)
self._modules[key] = importlib.import_module(
f"lerobot.common.policies.{module_name}.modeling_{module_name}"
self._modules[key] = importlib.import_module(import_path.format(module_name))
logger.debug(
f"Policy module: {module_name} imported from {import_path.format(module_name)}"
)
return getattr(self._modules[key], value)
except (ImportError, AttributeError):
for import_path in [
f"lerobot.policies.{module_name}",
f"lerobot.common.policies.{module_name}",
]:
try:
print("Importing POLICY: ",module_name)
self._modules[key] = importlib.import_module(import_path)
if hasattr(self._modules[key], value):
return getattr(self._modules[key], value)
except ImportError:
continue
raise ImportError(f"Could not find policy class {value} for policy type {key}")
raise ImportError(
f"Could not find policy class {value} for policy type {key}. "
f"Tried paths: {[p.format(module_name) for p in POLICY_IMPORT_PATHS]}"
)
def register(
self,
@ -166,22 +164,29 @@ class _LazyPolicyMapping(OrderedDict):
value: Type[PreTrainedPolicy],
config_class: Type[PreTrainedConfig],
exist_ok: bool = False,
):
) -> None:
"""Register a new policy class with its configuration class."""
if not isinstance(key, str):
raise TypeError(f"Key must be a string, got {type(key)}")
if not issubclass(value, PreTrainedPolicy):
raise TypeError(f"Value must be a PreTrainedPolicy subclass, got {type(value)}")
if not issubclass(config_class, PreTrainedConfig):
raise TypeError(f"Config class must be a PreTrainedConfig subclass, got {type(config_class)}")
if key in self._mapping and not exist_ok:
raise ValueError(f"'{key}' is already used by a Policy, pick another name.")
self._extra_content[key] = value
self._config_mapping[config_class] = value
def get_policy_for_config(self, config_class: PreTrainedConfig) -> Type[PreTrainedPolicy]:
def get_policy_for_config(self, config_class: Type[PreTrainedConfig]) -> Type[PreTrainedPolicy]:
"""Get the policy class associated with a config class."""
# First check direct config class mapping
if type(config_class) in self._config_mapping:
return self._config_mapping[type(config_class)]
if config_class in self._config_mapping:
return self._config_mapping[config_class]
# Try to find by policy type
try:
policy_type = config_class.type
policy_type = config_class.get_type_str()
# Check extra content first
if policy_type in self._extra_content:
return self._extra_content[policy_type]
@ -222,10 +227,7 @@ class AutoPolicyConfig:
"""
def __init__(self):
raise OSError(
"AutoPolicyConfig is designed to be instantiated "
"using the `AutoPolicyConfig.from_pretrained(TODO)` method."
)
raise OSError("AutoPolicyConfig not meant to be instantiated directly")
@classmethod
def for_policy(cls, policy_type: str, *args, **kwargs) -> PreTrainedConfig:
@ -246,13 +248,12 @@ class AutoPolicyConfig:
policy_type (`str`): The policy type like "act" or "pi0".
config ([`PreTrainedConfig`]): The config to register.
"""
# TODO(Steven): config.type doesn't work at this stage because it is not an instance, it the class definition
# if issubclass(config, PreTrainedConfig) and config.type != policy_type:
# raise ValueError(
# "The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
# f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they "
# "match!"
# )
if issubclass(config, PreTrainedConfig) and config.get_type_str() != policy_type:
raise ValueError(
"The config you are passing has a `policy_type` attribute that is not consistent with the policy type "
f"you passed (config has {config.type} and you passed {policy_type}. Fix one of those so they "
"match!"
)
POLICY_CONFIG_MAPPING.register(policy_type, config, exist_ok=exist_ok)
@classmethod
@ -308,15 +309,12 @@ class AutoPolicy:
"""
def __init__(self):
raise OSError(
"AutoPolicy is designed to be instantiated using the "
"`AutoPolicy.from_config()` or `AutoPolicy.from_pretrained()` methods."
)
raise OSError("AutoPolicy not meant to be instantiated directly")
@classmethod
def from_config(cls, config: PreTrainedConfig, **kwargs) -> PreTrainedPolicy:
"""Instantiate a policy from a configuration."""
policy_class = POLICY_MAPPING.get_policy_for_config(config)
policy_class = POLICY_MAPPING.get_policy_for_config(type(config))
return policy_class(config, **kwargs)
@classmethod
@ -356,16 +354,33 @@ class AutoPolicy:
policy_class: The policy class to register
exist_ok: Whether to allow overwriting existing registrations
"""
POLICY_MAPPING.register(config_class.type, policy_class, config_class, exist_ok=exist_ok)
POLICY_MAPPING.register(config_class.get_type_str(), policy_class, config_class, exist_ok=exist_ok)
def main():
"""Test the AutoPolicy and AutoPolicyConfig functionality."""
# Simulates a build-in policy type being loaded
# Built-in policies work without explicit registration
def test_error_cases():
"""Test error handling"""
try:
AutoPolicyConfig()
except OSError as e:
assert "not meant to be instantiated directly" in str(e)
try:
AutoPolicy()
except OSError as e:
assert "not meant to be instantiated directly" in str(e)
# config = AutoPolicyConfig.for_policy("vqbet")
config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht")
# try:
# AutoPolicy.from_config("invalid_config")
# except ValueError as e:
# assert "Unrecognized policy identifier" in str(e)
logging.basicConfig(level=logging.DEBUG)
# Test built-in policy loading
# config = AutoPolicyConfig.from_pretrained("lerobot/vqbet_pusht")
config = AutoPolicyConfig.for_policy("vqbet")
policy = AutoPolicy.from_config(config)
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
@ -374,8 +389,7 @@ def main():
assert isinstance(config, VQBeTConfig)
assert isinstance(policy, VQBeTPolicy)
# Simulates a new policy type being registered
# Only new policies need registration
# Test policy registration
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@ -387,6 +401,9 @@ def main():
assert isinstance(my_new_config, TDMPCConfig)
assert isinstance(my_new_policy, TDMPCPolicy)
# Run error case tests
test_error_cases()
if __name__ == "__main__":
main()

View File

@ -47,6 +47,15 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
def type(self) -> str:
return self.get_choice_name(self.__class__)
# TODO(Steven): Find a better way to do deal with this
@classmethod
def get_type_str(cls) -> str:
"""Get the policy type identifier for this configuration class."""
class_name = cls.__name__.lower()
if class_name.endswith("config"):
return class_name[:-6] # Remove 'config' suffix
return class_name
@abc.abstractproperty
def observation_delta_indices(self) -> list | None:
raise NotImplementedError