lerobot/lerobot/configs/parser.py

139 lines
5.2 KiB
Python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import sys
from argparse import ArgumentError
from functools import wraps
from pathlib import Path
from typing import Sequence
import draccus
from lerobot.common.utils.utils import has_method
# Suffix of the nested CLI argument (`--<field>.path=...`) that points a config field at a path.
PATH_KEY = "path"
# NOTE(review): presumably selects JSON as the file format draccus uses for configs -- confirm
# against the draccus documentation. This runs as a side effect of importing this module.
draccus.set_config_type("json")
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
    """Collect the CLI arguments nested under `field_name`, with the nesting removed.

    For example, if the main script was launched with:
        python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path
    then get_cli_overrides("arg2") returns:
        ["--subarg1=abc", "--subarg2=some/path"]

    The selector arguments `--<field_name>.<draccus.CHOICE_TYPE_KEY>=...` and
    `--<field_name>.<PATH_KEY>=...` are never part of the result.

    Args:
        field_name: Name of the top-level field whose nested arguments are wanted.
        args: Arguments to scan. When None, `sys.argv[1:]` is used.

    Returns:
        The de-nested arguments, in their original order (an empty list when none match).
    """
    cli_args = sys.argv[1:] if args is None else args
    nested_prefix = f"--{field_name}."
    # These two arguments select the field's type / source rather than override one of its values.
    reserved_prefixes = (
        f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=",
        f"--{field_name}.{PATH_KEY}=",
    )
    return [
        f"--{arg.removeprefix(nested_prefix)}"
        for arg in cli_args
        if arg.startswith(nested_prefix) and not arg.startswith(reserved_prefixes)
    ]
def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
if args is None:
args = sys.argv[1:]
prefix = f"--{arg_name}="
for arg in args:
if arg.startswith(prefix):
return arg[len(prefix) :]
return None
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value of `--<field_name>.<PATH_KEY>=...`, or None if it is absent.

    `args` is forwarded unchanged to `parse_arg`.
    """
    path_arg_name = f"{field_name}.{PATH_KEY}"
    return parse_arg(path_arg_name, args)
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the value of `--<field_name>.<draccus.CHOICE_TYPE_KEY>=...`, or None if it is absent.

    `args` is forwarded unchanged to `parse_arg`.
    """
    type_arg_name = f"{field_name}.{draccus.CHOICE_TYPE_KEY}"
    return parse_arg(type_arg_name, args)
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
    """
    Removes the command-line arguments of every field that is given through a path argument.

    For each field in `fields_to_filter` that has a non-empty `--<field>.<PATH_KEY>=...`
    argument, every argument starting with `--<field>.` (the path argument included) is
    dropped. Fields without such a path argument are left untouched.

    Args:
        fields_to_filter (str | list[str]): A single field name or a list of field names to check.
        args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
            Defaults to None, in which case `sys.argv[1:]` is used.

    Returns:
        list[str]: A new list of arguments, with the arguments related to the path-specified
            fields removed. The input sequence is never returned as is nor modified.

    Raises:
        ArgumentError: If both a path argument (e.g., `--field_name.path`) and a type
            argument (e.g., `--field_name.type`) are specified for the same field.
    """
    if args is None:
        # Resolve the default once so that detection and filtering look at the same arguments
        # (previously None was either returned as is or iterated over).
        args = sys.argv[1:]
    if isinstance(fields_to_filter, str):
        fields_to_filter = [fields_to_filter]

    # Work on a copy so the declared list return type always holds.
    filtered_args = list(args)
    for field in fields_to_filter:
        if not get_path_arg(field, args):
            continue
        if get_type_arg(field, args):
            raise ArgumentError(
                argument=None,
                message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
            )
        nested_prefix = f"--{field}."
        filtered_args = [arg for arg in filtered_args if not arg.startswith(nested_prefix)]
    return filtered_args
def wrap(config_path: Path | None = None):
    """
    HACK: Similar to draccus.wrap, with two additions:
        - For config classes exposing '__get_path_fields__', the CLI arguments of fields given
          through a '.path' argument are stripped before parsing.
        - If '--config_path=...' is on the CLI and the config class has a 'from_pretrained'
          method, the config is built through that method instead of draccus.parse.

    The config class is taken from the type annotation of the decorated function's first
    positional parameter. If the wrapped function is called with an instance of exactly that
    class as first positional argument, the instance is used directly and the CLI is ignored.
    """

    def wrapper_outer(fn):
        @wraps(fn)
        def wrapper_inner(*args, **kwargs):
            spec = inspect.getfullargspec(fn)
            config_cls = spec.annotations[spec.args[0]]

            # Programmatic call with a ready-made config object: bypass CLI parsing entirely.
            if args and type(args[0]) is config_cls:
                cfg, args = args[0], args[1:]
                return fn(cfg, *args, **kwargs)

            remaining_args = sys.argv[1:]
            # Read config_path from the unfiltered CLI before any argument is removed.
            pretrained_path = parse_arg("config_path", remaining_args)
            if has_method(config_cls, "__get_path_fields__"):
                remaining_args = filter_path_args(config_cls.__get_path_fields__(), remaining_args)

            if has_method(config_cls, "from_pretrained") and pretrained_path:
                # The config_path argument is consumed here and not forwarded as a CLI override.
                remaining_args = filter_arg("config_path", remaining_args)
                cfg = config_cls.from_pretrained(pretrained_path, cli_args=remaining_args)
            else:
                cfg = draccus.parse(config_class=config_cls, config_path=config_path, args=remaining_args)
            return fn(cfg, *args, **kwargs)

        return wrapper_inner

    return wrapper_outer