This commit is contained in:
Cadene 2024-02-18 01:24:19 +00:00
parent a5c305a7a4
commit fdfb2010fd
1 changed files with 1 additions and 4 deletions

View File

@ -128,10 +128,7 @@ class TDMPC(nn.Module):
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {
k: o.detach().unsqueeze(0)
for k, o in obs.items()
}
obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()}
else:
obs = obs.detach().unsqueeze(0)
z = self.model.encode(obs)