diff --git a/lerobot/configs/parser.py b/lerobot/configs/parser.py index 476a9b40..39e31515 100644 --- a/lerobot/configs/parser.py +++ b/lerobot/configs/parser.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import inspect +import pkgutil import sys from argparse import ArgumentError from functools import wraps @@ -23,6 +25,7 @@ import draccus from lerobot.common.utils.utils import has_method PATH_KEY = "path" +PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" draccus.set_config_type("json") @@ -58,6 +61,86 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None: return None +def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict: + """Parse plugin-related arguments from command-line arguments. + + This function extracts arguments from command-line arguments that match a specified suffix pattern. + It processes arguments in the format '--key=value' and returns them as a dictionary. + + Args: + plugin_arg_suffix (str): The suffix to identify plugin-related arguments. + cli_args (Sequence[str]): A sequence of command-line arguments to parse. + + Returns: + dict: A dictionary containing the parsed plugin arguments where: + - Keys are the argument names (with '--' prefix removed if present) + - Values are the corresponding argument values + + Example: + >>> args = ['--env.discover_packages_path=my_package', + ... '--other_arg=value'] + >>> parse_plugin_args('discover_packages_path', args) + {'env.discover_packages_path': 'my_package'} + """ + plugin_args = {} + for arg in args: + if "=" in arg and plugin_arg_suffix in arg: + key, value = arg.split("=", 1) + # Remove leading '--' if present + if key.startswith("--"): + key = key[2:] + plugin_args[key] = value + return plugin_args + + +class PluginLoadError(Exception): + """Raised when a plugin fails to load.""" + + +def load_plugin(plugin_path: str) -> None: + """Load and initialize a plugin from a given Python package path. + + This function attempts to load a plugin by importing its package and any submodules. + Plugin registration is expected to happen during package initialization, i.e. when + the package is imported the gym environment should be registered and the config classes + registered with their parents using the `register_subclass` decorator. + + Args: + plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin") + + Raises: + PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid. + + Examples: + >>> load_plugin("external_plugin.core") # Loads plugin from external package + + Notes: + - The plugin package should handle its own registration during import + - All submodules in the plugin package will be imported + - Implementation follows the plugin discovery pattern from Python packaging guidelines + + See Also: + https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ + """ + try: + package_module = importlib.import_module(plugin_path, __package__) + except (ImportError, ModuleNotFoundError) as e: + raise PluginLoadError( + f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}" + ) from e + + def iter_namespace(ns_pkg): + return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".") + + try: + for _finder, pkg_name, _ispkg in iter_namespace(package_module): + importlib.import_module(pkg_name) + except ImportError as e: + raise PluginLoadError( + f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}" + ) from e + + def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None: return parse_arg(f"{field_name}.{PATH_KEY}", args) @@ -105,10 +188,13 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No def wrap(config_path: Path | None = None): """ - HACK: Similar to draccus.wrap but does two additional things: + HACK: Similar to draccus.wrap but does three additional things: - Will remove '.path' arguments from CLI in order to process them later on. - If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will initialize it from there to allow to fetch configs from the hub directly + - Will load plugins specified in the CLI arguments. These plugins will typically register + their own subclasses of config classes, so that draccus can find the right class to instantiate + from the CLI '.type' arguments """ def wrapper_outer(fn): @@ -121,6 +207,14 @@ def wrap(config_path: Path | None = None): args = args[1:] else: cli_args = sys.argv[1:] + plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args) + for plugin_cli_arg, plugin_path in plugin_args.items(): + try: + load_plugin(plugin_path) + except PluginLoadError as e: + # add the relevant CLI arg to the error message + raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e + cli_args = filter_arg(plugin_cli_arg, cli_args) config_path_cli = parse_arg("config_path", cli_args) if has_method(argtype, "__get_path_fields__"): path_fields = argtype.__get_path_fields__() diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py new file mode 100644 index 00000000..1a8cceed --- /dev/null +++ b/tests/configs/test_plugin_loading.py @@ -0,0 +1,89 @@ +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Generator + +import pytest + +from lerobot.common.envs.configs import EnvConfig +from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap + + +def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str: + """Creates a dummy plugin module that implements its own EnvConfig subclass.""" + return f""" +from dataclasses import dataclass +from lerobot.common.envs.configs import {base_class} + +@{base_class}.register_subclass("{plugin_name}") +@dataclass +class TestPluginConfig: + value: int = 42 + """ + + +@pytest.fixture +def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]: + """Creates a temporary plugin package structure.""" + plugin_pkg = tmp_path / "test_plugin" + plugin_pkg.mkdir() + (plugin_pkg / "__init__.py").touch() + + with open(plugin_pkg / "my_plugin.py", "w") as f: + f.write(create_plugin_code()) + + # Add tmp_path to Python path so we can import from it + sys.path.insert(0, str(tmp_path)) + yield plugin_pkg + sys.path.pop(0) + + +def test_parse_plugin_args(): + cli_args = [ + "--env.type=test", + "--model.discover_packages_path=some.package", + "--env.discover_packages_path=other.package", + ] + plugin_args = parse_plugin_args("discover_packages_path", cli_args) + assert plugin_args == { + "model.discover_packages_path": "some.package", + "env.discover_packages_path": "other.package", + } + + +def test_load_plugin_success(plugin_dir: Path): + # Import should work and register the plugin with the real EnvConfig + load_plugin("test_plugin") + + assert "test_env" in EnvConfig.get_known_choices() + plugin_cls = EnvConfig.get_choice_class("test_env") + plugin_instance = plugin_cls() + assert plugin_instance.value == 42 + + +def test_load_plugin_failure(): + with pytest.raises(PluginLoadError) as exc_info: + load_plugin("nonexistent_plugin") + assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value) + + +def test_wrap_with_plugin(plugin_dir: Path): + @dataclass + class Config: + env: EnvConfig + + @wrap() + def dummy_func(cfg: Config): + return cfg + + # Test loading plugin via CLI args + sys.argv = [ + "dummy_script.py", + "--env.discover_packages_path=test_plugin", + "--env.type=test_env", + ] + + cfg = dummy_func() + assert isinstance(cfg, Config) + assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env")) + assert cfg.env.value == 42