fix unit tests
This commit is contained in:
parent
cb30d7a8bf
commit
19d410a372
|
@ -167,6 +167,7 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
|||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
use_amp = hydra_cfg.use_amp
|
||||
policy_fps = hydra_cfg.env.fps
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
@ -174,8 +175,6 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
|||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
policy_fps = hydra_cfg.env.fps
|
||||
return policy, policy_fps, device, use_amp
|
||||
|
||||
|
||||
|
|
|
@ -223,9 +223,12 @@ def record(
|
|||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
|
||||
if fps != policy_fps:
|
||||
if fps is None:
|
||||
fps = policy_fps
|
||||
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
|
||||
elif fps != policy_fps:
|
||||
logging.warning(
|
||||
f"There is a mismatch between the provided fps ({fps}) and the one from policy config {policy_fps}."
|
||||
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
|
||||
)
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
|
|
Loading…
Reference in New Issue