From 62b18a7607d955eed60ba7eff70b71162f5acaf2 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 8 Apr 2024 14:51:45 +0100 Subject: [PATCH] Add type hints --- lerobot/common/policies/act/policy.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 7fb03576..e14a1e88 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module): if self.n_action_steps is not None: self._action_queue = deque([], maxlen=self.n_action_steps) - def select_action(self, batch: dict[str, Tensor], *_): + def select_action(self, batch: dict[str, Tensor], *_) -> Tensor: """ This method wraps `select_actions` in order to return one action at a time for execution in the environment. It works by managing the actions in a queue and only calling `select_actions` when the @@ -189,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module): return self._action_queue.popleft() @torch.no_grad() - def select_actions(self, batch: dict[str, Tensor]): + def select_actions(self, batch: dict[str, Tensor]) -> Tensor: """Use the action chunking transformer to generate a sequence of actions.""" self.eval() self._preprocess_batch(batch, add_obs_steps_dim=True) @@ -211,7 +211,7 @@ class ActionChunkingTransformerPolicy(nn.Module): return action[: self.n_action_steps] - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> dict: # TODO(now): Temporary bridge until we know what to do about the `update` method. return self.update(*args, **kwargs) @@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get # the image index dimension. - def update(self, batch, *_): + def update(self, batch, *_) -> dict: start_time = time.time() self._preprocess_batch(batch) @@ -277,7 +277,7 @@ class ActionChunkingTransformerPolicy(nn.Module): return info - def forward(self, batch: dict[str, Tensor], return_loss: bool = False): + def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor: images = self.image_normalizer(batch["observation.images.top"]) if return_loss: # training time @@ -309,7 +309,9 @@ class ActionChunkingTransformerPolicy(nn.Module): action, _ = self._forward(batch["observation.state"], images) return action - def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None): + def _forward( + self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: """ Args: robot_state: (B, J) batch of robot joint configurations. @@ -410,7 +412,7 @@ class ActionChunkingTransformerPolicy(nn.Module): actions = self.action_head(decoder_out) - return actions, [mu, log_sigma_x2] + return actions, (mu, log_sigma_x2) def save(self, fp): torch.save(self.state_dict(), fp)