Do not override fps

This commit is contained in:
Remi Cadene 2024-10-15 18:27:22 +02:00
parent 3960125b07
commit cb30d7a8bf
2 changed files with 10 additions and 9 deletions

View File

@ -158,7 +158,8 @@ def init_keyboard_listener():
return listener, events
def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
def init_policy(pretrained_policy_name_or_path, policy_overrides):
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
@ -174,14 +175,8 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
# override fps using policy fps
policy_fps = hydra_cfg.env.fps
if fps != policy_fps:
logging.warning(f"Overrides fps from provided one {fps} to the one from policy config {policy_fps}")
fps = policy_fps
return policy, fps, device, use_amp
return policy, policy_fps, device, use_amp
def warmup_record(

View File

@ -99,6 +99,7 @@ python lerobot/scripts/control_robot.py record \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List
@ -220,7 +221,12 @@ def record(
# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps)
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
if fps != policy_fps:
logging.warning(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config {policy_fps}."
)
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)