diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 4a8df1ce..3aab03cf 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -98,13 +98,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) if len(self._action_queue) == 0: - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - actions = self.model(batch)[0][: self.config.n_action_steps] + actions = self.model(batch)[0][:, : self.config.n_action_steps] # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py index e79a94ff..7ce7b881 100644 --- a/tests/scripts/save_policy_to_safetensor.py +++ b/tests/scripts/save_policy_to_safetensor.py @@ -105,6 +105,6 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": # Instructions: include the policies that you want to save artifacts for here. Please make sure to revert # your changes when you are done. - env_policies = [] + env_policies = [("aloha", "act", ["policy.n_action_steps=10"])] for env, policy, extra_overrides in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)