fix control_robot overrides as options

This commit is contained in:
Remi Cadene 2024-07-27 13:32:35 +02:00
parent 57af6b90a2
commit a5e2571881
1 changed files with 7 additions and 5 deletions

View File

@ -638,7 +638,8 @@ if __name__ == "__main__":
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
) )
base_parser.add_argument( base_parser.add_argument(
"robot_overrides", "--robot-overrides",
type=str,
nargs="*", nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)", help="Any key=value arguments to override config values (use dots for.nested=overrides)",
) )
@ -717,7 +718,8 @@ if __name__ == "__main__":
), ),
) )
parser_record.add_argument( parser_record.add_argument(
"overrides", "--policy-overrides",
type=str,
nargs="*", nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)", help="Any key=value arguments to override config values (use dots for.nested=overrides)",
) )
@ -760,14 +762,14 @@ if __name__ == "__main__":
elif control_mode == "record": elif control_mode == "record":
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
overrides = args.overrides policy_overrides = args.policy_overrides
del kwargs["pretrained_policy_name_or_path"] del kwargs["pretrained_policy_name_or_path"]
del kwargs["overrides"] del kwargs["policy_overrides"]
policy_cfg = None policy_cfg = None
if pretrained_policy_name_or_path is not None: if pretrained_policy_name_or_path is not None:
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path) pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides) policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path) policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
record(robot, policy, policy_cfg, **kwargs) record(robot, policy, policy_cfg, **kwargs)
else: else: