diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 826ca5de..aeed3320 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -158,7 +158,8 @@ def init_keyboard_listener(): return listener, events -def init_policy(pretrained_policy_name_or_path, policy_overrides, fps): +def init_policy(pretrained_policy_name_or_path, policy_overrides): + """Instantiate the policy and load fps, device and use_amp from config yaml""" pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path) hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides) policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) @@ -174,14 +175,8 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps): torch.backends.cuda.matmul.allow_tf32 = True set_global_seed(hydra_cfg.seed) - # override fps using policy fps policy_fps = hydra_cfg.env.fps - - if fps != policy_fps: - logging.warning(f"Overrides fps from provided one {fps} to the one from policy config {policy_fps}") - fps = policy_fps - - return policy, fps, device, use_amp + return policy, policy_fps, device, use_amp def warmup_record( diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 59c16eef..f61ce0ec 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -99,6 +99,7 @@ python lerobot/scripts/control_robot.py record \ """ import argparse +import logging import time from pathlib import Path from typing import List @@ -220,7 +221,12 @@ def record( # Load pretrained policy if pretrained_policy_name_or_path is not None: - policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps) + policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) + + if fps != policy_fps: + logging.warning( + f"There is a mismatch between the provided fps ({fps}) and the one from policy config {policy_fps}." + ) # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy)