Do not override fps
This commit is contained in:
parent
3960125b07
commit
cb30d7a8bf
|
@ -158,7 +158,8 @@ def init_keyboard_listener():
|
||||||
return listener, events
|
return listener, events
|
||||||
|
|
||||||
|
|
||||||
def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
|
def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||||
|
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
|
||||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||||
|
@ -174,14 +175,8 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_global_seed(hydra_cfg.seed)
|
set_global_seed(hydra_cfg.seed)
|
||||||
|
|
||||||
# override fps using policy fps
|
|
||||||
policy_fps = hydra_cfg.env.fps
|
policy_fps = hydra_cfg.env.fps
|
||||||
|
return policy, policy_fps, device, use_amp
|
||||||
if fps != policy_fps:
|
|
||||||
logging.warning(f"Overrides fps from provided one {fps} to the one from policy config {policy_fps}")
|
|
||||||
fps = policy_fps
|
|
||||||
|
|
||||||
return policy, fps, device, use_amp
|
|
||||||
|
|
||||||
|
|
||||||
def warmup_record(
|
def warmup_record(
|
||||||
|
|
|
@ -99,6 +99,7 @@ python lerobot/scripts/control_robot.py record \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
@ -220,7 +221,12 @@ def record(
|
||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
if pretrained_policy_name_or_path is not None:
|
if pretrained_policy_name_or_path is not None:
|
||||||
policy, fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides, fps)
|
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||||
|
|
||||||
|
if fps != policy_fps:
|
||||||
|
logging.warning(
|
||||||
|
f"There is a mismatch between the provided fps ({fps}) and the one from policy config {policy_fps}."
|
||||||
|
)
|
||||||
|
|
||||||
# Create empty dataset or load existing saved episodes
|
# Create empty dataset or load existing saved episodes
|
||||||
sanity_check_dataset_name(repo_id, policy)
|
sanity_check_dataset_name(repo_id, policy)
|
||||||
|
|
Loading…
Reference in New Issue