139 lines
5.2 KiB
Python
139 lines
5.2 KiB
Python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import inspect
|
|
import sys
|
|
from argparse import ArgumentError
|
|
from functools import wraps
|
|
from pathlib import Path
|
|
from typing import Sequence
|
|
|
|
import draccus
|
|
|
|
from lerobot.common.utils.utils import has_method
|
|
|
|
# Sub-key that identifies a field's "path" argument on the CLI, i.e. `--<field>.path=<value>`.
PATH_KEY = "path"
|
|
# Runs at import time (module-level side effect on draccus).
# NOTE(review): presumably selects JSON as the config file format draccus uses - confirm against draccus docs.
draccus.set_config_type("json")
|
|
|
|
|
|
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
    """Collect the CLI arguments nested under `field_name`, with that nesting level stripped.

    For example, supposing the main script was launched with:
        python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path

    then, during execution of myscript.py, get_cli_overrides("arg2") returns:
        ["--subarg1=abc", "--subarg2=some/path"]

    Arguments starting with `--<field_name>.<draccus.CHOICE_TYPE_KEY>=` or
    `--<field_name>.<PATH_KEY>=` are left out of the result.

    Args:
        field_name: Name of the nested attribute whose overrides are wanted.
        args: Arguments to scan. Falls back to `sys.argv[1:]` when None.

    Returns:
        The de-nested arguments, in their original order (possibly empty).
    """
    cli_args = sys.argv[1:] if args is None else args
    nested_prefix = f"--{field_name}."
    skipped_prefixes = (
        f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
        f"--{field_name}.{PATH_KEY}=",
    )
    # str.startswith accepts a tuple, so both skipped forms are checked in one call.
    return [
        f"--{candidate.removeprefix(nested_prefix)}"
        for candidate in cli_args
        if candidate.startswith(nested_prefix) and not candidate.startswith(skipped_prefixes)
    ]
|
|
|
|
|
|
def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
|
if args is None:
|
|
args = sys.argv[1:]
|
|
prefix = f"--{arg_name}="
|
|
for arg in args:
|
|
if arg.startswith(prefix):
|
|
return arg[len(prefix) :]
|
|
return None
|
|
|
|
|
|
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value given as `--<field_name>.<PATH_KEY>=<value>`, or None when absent.

    `args` is forwarded unchanged to `parse_arg`.
    """
    path_arg_name = f"{field_name}.{PATH_KEY}"
    return parse_arg(path_arg_name, args)
|
|
|
|
|
|
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value given as `--<field_name>.<draccus.CHOICE_TYPE_KEY>=<value>`, or None when absent.

    `args` is forwarded unchanged to `parse_arg`.
    """
    type_arg_name = f"{field_name}.{draccus.CHOICE_TYPE_KEY}"
    return parse_arg(type_arg_name, args)
|
|
|
|
|
|
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
|
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
|
|
|
|
|
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
    """
    Filters command-line arguments related to fields with specific path arguments.

    For every field that has a non-empty `--<field>.<PATH_KEY>=...` argument, all arguments
    starting with `--<field>.` are removed. Fields without such a path argument are left
    untouched.

    Args:
        fields_to_filter (str | list[str]): A single str or a list of str whose arguments need to be filtered.
        args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
            Defaults to None, in which case `sys.argv[1:]` is used.

    Returns:
        list[str]: A new, filtered list of arguments, with arguments related to the specified
            fields removed.

    Raises:
        ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
            argument (e.g., `--field_name.type`) are specified for the same field.
    """
    # Resolve the default once so lookups and filtering operate on the same arguments
    # (previously the lookups fell back to sys.argv while the filtering ran on None).
    if args is None:
        args = sys.argv[1:]
    if isinstance(fields_to_filter, str):
        fields_to_filter = [fields_to_filter]

    # Copy so the declared list[str] return type holds whatever sequence was passed in.
    filtered_args = list(args)
    for field in fields_to_filter:
        # Truthiness check: an empty path value (`--<field>.path=`) counts as "no path given".
        if not get_path_arg(field, args):
            continue
        if get_type_arg(field, args):
            raise ArgumentError(
                argument=None,
                message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
            )
        field_prefix = f"--{field}."
        filtered_args = [arg for arg in filtered_args if not arg.startswith(field_prefix)]

    return filtered_args
|
|
|
|
|
|
def wrap(config_path: Path | None = None):
|
|
"""
|
|
HACK: Similar to draccus.wrap but does two additional things:
|
|
- Will remove '.path' arguments from CLI in order to process them later on.
|
|
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
|
|
initialize it from there to allow to fetch configs from the hub directly
|
|
"""
|
|
|
|
def wrapper_outer(fn):
|
|
@wraps(fn)
|
|
def wrapper_inner(*args, **kwargs):
|
|
argspec = inspect.getfullargspec(fn)
|
|
argtype = argspec.annotations[argspec.args[0]]
|
|
if len(args) > 0 and type(args[0]) is argtype:
|
|
cfg = args[0]
|
|
args = args[1:]
|
|
else:
|
|
cli_args = sys.argv[1:]
|
|
config_path_cli = parse_arg("config_path", cli_args)
|
|
if has_method(argtype, "__get_path_fields__"):
|
|
path_fields = argtype.__get_path_fields__()
|
|
cli_args = filter_path_args(path_fields, cli_args)
|
|
if has_method(argtype, "from_pretrained") and config_path_cli:
|
|
cli_args = filter_arg("config_path", cli_args)
|
|
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
|
else:
|
|
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
|
response = fn(cfg, *args, **kwargs)
|
|
return response
|
|
|
|
return wrapper_inner
|
|
|
|
return wrapper_outer
|