diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index aeed3320..e4ed73f7 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -167,6 +167,7 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides): # Check device is available device = get_safe_torch_device(hydra_cfg.device, log=True) use_amp = hydra_cfg.use_amp + policy_fps = hydra_cfg.env.fps policy.eval() policy.to(device) @@ -174,8 +175,6 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True set_global_seed(hydra_cfg.seed) - - policy_fps = hydra_cfg.env.fps return policy, policy_fps, device, use_amp diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index f61ce0ec..3fdef1ff 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -223,9 +223,12 @@ def record( if pretrained_policy_name_or_path is not None: policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) - if fps != policy_fps: + if fps is None: + fps = policy_fps + logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") + elif fps != policy_fps: logging.warning( - f"There is a mismatch between the provided fps ({fps}) and the one from policy config {policy_fps}." + f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})." ) # Create empty dataset or load existing saved episodes