# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import sys
from argparse import ArgumentError
from functools import wraps
from pathlib import Path
from typing import Sequence

import draccus

from lerobot.common.utils.utils import has_method

PATH_KEY = "path"
draccus.set_config_type("json")


def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str]:
    """Parses arguments from cli at a given nested attribute level.

    For example, supposing the main script was called with:
    python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path

    If called during execution of myscript.py, get_cli_overrides("arg2") will return:
    ["--subarg1=abc", "--subarg2=some/path"]

    The `--<field_name>.<draccus.CHOICE_TYPE_KEY>=` and `--<field_name>.path=` arguments are
    skipped; they can be read with `get_type_arg` and `get_path_arg` instead.

    Args:
        field_name: Name of the nested attribute whose arguments are extracted.
        args: Arguments to scan. Defaults to `sys.argv[1:]`.

    Returns:
        The matching arguments with the `<field_name>.` level stripped (possibly empty).
    """
    if args is None:
        args = sys.argv[1:]

    detect_string = f"--{field_name}."
    exclude_strings = (
        f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
        f"--{field_name}.{PATH_KEY}=",
    )
    return [
        f"--{arg.removeprefix(detect_string)}"
        for arg in args
        if arg.startswith(detect_string) and not arg.startswith(exclude_strings)
    ]


def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
    """Returns the value of the first `--<arg_name>=<value>` argument, or None if absent.

    Args:
        arg_name: Argument name, without the leading dashes.
        args: Arguments to scan. Defaults to `sys.argv[1:]`.
    """
    if args is None:
        args = sys.argv[1:]

    prefix = f"--{arg_name}="
    for arg in args:
        if arg.startswith(prefix):
            return arg[len(prefix) :]
    return None


def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Returns the value of `--<field_name>.path=<value>`, or None if absent."""
    return parse_arg(f"{field_name}.{PATH_KEY}", args)


def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Returns the value of `--<field_name>.<draccus.CHOICE_TYPE_KEY>=<value>`, or None if absent."""
    return parse_arg(f"{field_name}.{draccus.CHOICE_TYPE_KEY}", args)


def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
    """Returns `args` without any `--<field_to_filter>=<value>` argument.

    Args:
        field_to_filter: Argument name to drop, without the leading dashes.
        args: Arguments to filter. Defaults to `sys.argv[1:]`.
    """
    # Same fallback as the other helpers; iterating over the None default would raise TypeError.
    if args is None:
        args = sys.argv[1:]

    prefix = f"--{field_to_filter}="
    return [arg for arg in args if not arg.startswith(prefix)]


def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
    """
    Filters command-line arguments related to fields with specific path arguments.

    For each field given with a non-empty `--<field>.path=` argument, every argument starting
    with `--<field>.` is removed. Fields without a path argument are left untouched.

    Args:
        fields_to_filter (str | list[str]): A single str or a list of str whose arguments need to be filtered.
        args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
            Defaults to None, in which case `sys.argv[1:]` is used.

    Returns:
        list[str]: A filtered list of arguments, with arguments related to the specified
        fields removed.

    Raises:
        ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
            argument (e.g., `--field_name.type`) are specified for the same field.
    """
    if isinstance(fields_to_filter, str):
        fields_to_filter = [fields_to_filter]

    # Resolve the default once so that the lookups and the filtering work on the same arguments.
    if args is None:
        args = sys.argv[1:]

    # Work on a copy: the result is always a new list and the caller's sequence is left untouched.
    filtered_args = list(args)
    for field in fields_to_filter:
        if not get_path_arg(field, args):
            continue
        if get_type_arg(field, args):
            raise ArgumentError(
                argument=None,
                message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
            )
        field_prefix = f"--{field}."
        filtered_args = [arg for arg in filtered_args if not arg.startswith(field_prefix)]

    return filtered_args


def wrap(config_path: Path | None = None):
    """
    HACK: Similar to draccus.wrap but does two additional things:
        - Will remove '.path' arguments from CLI in order to process them later on.
        - If a '--config_path=' argument is passed on the CLI and the main config class has a
          'from_pretrained' method, will initialize it from there to allow to fetch configs
          from the hub directly

    Args:
        config_path: Forwarded to `draccus.parse` when the config is not built through
            `from_pretrained`.

    Returns:
        A decorator for functions whose first positional parameter is annotated with the
        config class.
    """

    def wrapper_outer(fn):
        @wraps(fn)
        def wrapper_inner(*args, **kwargs):
            # The config class is the annotation of the first positional parameter of `fn`.
            argspec = inspect.getfullargspec(fn)
            argtype = argspec.annotations[argspec.args[0]]

            if len(args) > 0 and type(args[0]) is argtype:
                # An already-built config was passed directly: skip CLI parsing.
                cfg = args[0]
                args = args[1:]
            else:
                cli_args = sys.argv[1:]
                config_path_cli = parse_arg("config_path", cli_args)

                if has_method(argtype, "__get_path_fields__"):
                    path_fields = argtype.__get_path_fields__()
                    cli_args = filter_path_args(path_fields, cli_args)

                if has_method(argtype, "from_pretrained") and config_path_cli:
                    # `--config_path=` is consumed here, so it is removed from the overrides.
                    cli_args = filter_arg("config_path", cli_args)
                    cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
                else:
                    cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)

            response = fn(cfg, *args, **kwargs)
            return response

        return wrapper_inner

    return wrapper_outer