backup wip

This commit is contained in:
Alexander Soare 2024-04-09 09:13:58 +01:00
parent e6c6c2367f
commit 381cc70117
2 changed files with 3 additions and 3 deletions

View File

@ -107,7 +107,6 @@ class TDMPCPolicy(nn.Module):
self.model_target = deepcopy(self.model) self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr) self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval() self.model.eval()
self.model_target.eval() self.model_target.eval()

View File

@ -113,11 +113,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError() raise NotImplementedError()
if job_name is None: if job_name is None:
raise NotImplementedError() raise NotImplementedError()
if cfg.online_steps > 0:
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
init_logging() init_logging()
if cfg.online_steps and cfg.rollout_batch_size == 1:
logging.warning("rollout_batch_size > 1 not supported for online training steps")
# Check device is available # Check device is available
get_safe_torch_device(cfg.device, log=True) get_safe_torch_device(cfg.device, log=True)