Fix missing `policy.to(device)` in policy factory ()

This commit is contained in:
Alexander Soare 2024-05-01 17:26:58 +01:00 committed by GitHub
parent d1855a202a
commit c1668924ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions
lerobot/common/policies

View File

@ -65,8 +65,9 @@ def make_policy(
if pretrained_policy_name_or_path is None:
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
policy = policy_cls(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
policy.to(get_safe_torch_device(hydra_cfg.device))
return policy