Select from the action dimension instead of the batch dimension
This commit is contained in:
parent
ca1c184cb1
commit
92cd9d33f1
|
@ -98,13 +98,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self.model(batch)[0][: self.config.n_action_steps]
|
||||
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
|
|
|
@ -105,6 +105,6 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
|||
if __name__ == "__main__":
|
||||
# Instructions: include the policies that you want to save artifacts for here. Please make sure to revert
|
||||
# your changes when you are done.
|
||||
env_policies = []
|
||||
env_policies = [("aloha", "act", ["policy.n_action_steps=10"])]
|
||||
for env, policy, extra_overrides in env_policies:
|
||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
||||
|
|
Loading…
Reference in New Issue