backup wip
This commit is contained in:
parent e6c6c2367f
commit 381cc70117
@@ -107,7 +107,6 @@ class TDMPCPolicy(nn.Module):
         self.model_target = deepcopy(self.model)
         self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
-        # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
 
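For context, the lines this hunk touches implement a standard target-network and optimizer setup. The sketch below is a minimal, runnable illustration of that pattern only; DummyModel, its layer sizes, and the cfg namespace are placeholders rather than the actual TD-MPC model or config.

# Minimal sketch of the target-network / optimizer pattern touched by this hunk.
# DummyModel and cfg are stand-ins, not the actual TD-MPC code.
from copy import deepcopy
from types import SimpleNamespace

import torch
from torch import nn


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self._pi = nn.Linear(8, 2)  # policy head, optimized separately


cfg = SimpleNamespace(lr=1e-3)

model = DummyModel()
# Target network: a copy of the live model, updated slowly elsewhere.
model_target = deepcopy(model)

# One optimizer for the full model, a separate one for the policy head only.
optim = torch.optim.Adam(model.parameters(), lr=cfg.lr)
pi_optim = torch.optim.Adam(model._pi.parameters(), lr=cfg.lr)

# Both networks start in eval mode; training code toggles modes explicitly.
model.eval()
model_target.eval()

Keeping a deepcopy target alongside the live model lets TD targets be computed against slowly updated parameters while the optimizers only ever step the live model.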
@@ -113,11 +113,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         raise NotImplementedError()
     if job_name is None:
         raise NotImplementedError()
-    if cfg.online_steps > 0:
-        assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"
 
     init_logging()
 
+    if cfg.online_steps and cfg.rollout_batch_size == 1:
+        logging.warning("rollout_batch_size > 1 not supported for online training steps")
+
     # Check device is available
     get_safe_torch_device(cfg.device, log=True)
 
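For context on the second hunk, the sketch below contrasts the removed assert guard with the warning-based guard that replaces it. guard_old, guard_new, and the cfg namespace are illustrative names, not code from the repository; guard_new reproduces the condition exactly as it appears in the hunk.

# Sketch of the guard change in train(): a hard assert becomes a logged warning.
# cfg is a stand-in namespace; only the two fields used by the guard are modeled.
import logging
from types import SimpleNamespace

cfg = SimpleNamespace(online_steps=100, rollout_batch_size=1)


def guard_old(cfg):
    # Pre-change behavior: refuse to run online training with rollout_batch_size > 1.
    if cfg.online_steps > 0:
        assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps"


def guard_new(cfg):
    # Post-change behavior, with the condition as written in the hunk.
    if cfg.online_steps and cfg.rollout_batch_size == 1:
        logging.warning("rollout_batch_size > 1 not supported for online training steps")


guard_old(cfg)
guard_new(cfg)

The practical difference is that an assert aborts the run when its condition fails, while logging.warning only records a message and lets training continue.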