From 381cc70117cb1c0481274e7d86510b06dd62ea42 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 9 Apr 2024 09:13:58 +0100 Subject: [PATCH] backup wip --- lerobot/common/policies/tdmpc/policy.py | 1 - lerobot/scripts/train.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 942ee9b1..787b0d63 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -107,7 +107,6 @@ class TDMPCPolicy(nn.Module): self.model_target = deepcopy(self.model) self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr) - # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) self.model.eval() self.model_target.eval() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index caaf5182..ce8f3488 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -113,11 +113,12 @@ def train(cfg: dict, out_dir=None, job_name=None): raise NotImplementedError() if job_name is None: raise NotImplementedError() - if cfg.online_steps > 0: - assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps" init_logging() + if cfg.online_steps and cfg.rollout_batch_size == 1: + logging.warning("rollout_batch_size > 1 not supported for online training steps") + # Check device is available get_safe_torch_device(cfg.device, log=True)