diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 808a3145..4819ca80 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -65,8 +65,9 @@ def make_policy( if pretrained_policy_name_or_path is None: policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) policy = policy_cls(policy_cfg, dataset_stats) - policy.to(get_safe_torch_device(hydra_cfg.device)) else: policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) + policy.to(get_safe_torch_device(hydra_cfg.device)) + return policy