Fix missing `policy.to(device)` in policy factory (#126)
This commit is contained in:
parent
d1855a202a
commit
c1668924ab
lerobot/common/policies
|
@ -65,8 +65,9 @@ def make_policy(
|
|||
if pretrained_policy_name_or_path is None:
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
else:
|
||||
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
return policy
|
||||
|
|
Loading…
Reference in New Issue