Fix act action queue (#185)
This commit is contained in:
parent
c9069df9f1
commit
4d7d41cdee
|
@ -98,13 +98,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
|
|
||||||
if len(self._action_queue) == 0:
|
if len(self._action_queue) == 0:
|
||||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
||||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
|
||||||
actions = self.model(batch)[0][: self.config.n_action_steps]
|
|
||||||
|
|
||||||
# TODO(rcadene): make _forward return output dictionary?
|
# TODO(rcadene): make _forward return output dictionary?
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
|
||||||
|
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||||
|
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue