Fix missing `policy.to(device)` in policy factory (#126)
This commit is contained in:
parent
d1855a202a
commit
c1668924ab
|
@ -65,8 +65,9 @@ def make_policy(
|
||||||
if pretrained_policy_name_or_path is None:
|
if pretrained_policy_name_or_path is None:
|
||||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||||
policy = policy_cls(policy_cfg, dataset_stats)
|
policy = policy_cls(policy_cfg, dataset_stats)
|
||||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
|
||||||
else:
|
else:
|
||||||
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
||||||
|
|
||||||
|
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
Loading…
Reference in New Issue