From c1668924ab9e3d7f7a15e41e1c604aa2355b9175 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 1 May 2024 17:26:58 +0100 Subject: [PATCH] Fix missing `policy.to(device)` in policy factory (#126) --- lerobot/common/policies/factory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 808a3145..4819ca80 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -65,8 +65,9 @@ def make_policy( if pretrained_policy_name_or_path is None: policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) policy = policy_cls(policy_cfg, dataset_stats) - policy.to(get_safe_torch_device(hydra_cfg.device)) else: policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) + policy.to(get_safe_torch_device(hydra_cfg.device)) + return policy