backup wip
This commit is contained in:
parent
e6c6c2367f
commit
381cc70117
|
@ -107,7 +107,6 @@ class TDMPCPolicy(nn.Module):
|
|||
self.model_target = deepcopy(self.model)
|
||||
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
|
||||
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.model.eval()
|
||||
self.model_target.eval()
|
||||
|
||||
|
|
|
@ -113,11 +113,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
if cfg.online_steps > 0:
|
||||
assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
|
||||
|
||||
init_logging()
|
||||
|
||||
if cfg.online_steps and cfg.rollout_batch_size == 1:
|
||||
logging.warning("rollout_batch_size > 1 not supported for online training steps")
|
||||
|
||||
# Check device is available
|
||||
get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue