fix unit tests

This commit is contained in:
Remi Cadene 2024-10-15 18:51:55 +02:00
parent cb30d7a8bf
commit 19d410a372
2 changed files with 6 additions and 4 deletions

View File

@ -167,6 +167,7 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
use_amp = hydra_cfg.use_amp
policy_fps = hydra_cfg.env.fps
policy.eval()
policy.to(device)
@ -174,8 +175,6 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
policy_fps = hydra_cfg.env.fps
return policy, policy_fps, device, use_amp

View File

@ -223,9 +223,12 @@ def record(
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
if fps != policy_fps:
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
elif fps != policy_fps:
logging.warning(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config {policy_fps}."
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
)
# Create empty dataset or load existing saved episodes