Add type hints

Alexander Soare 2024-04-08 14:51:45 +01:00
parent 86365adf9f
commit 62b18a7607
1 changed file with 9 additions and 7 deletions


@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.n_action_steps)

-    def select_action(self, batch: dict[str, Tensor], *_):
+    def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
         """
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
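
The docstring above describes the serving pattern of this policy: predict a chunk of actions once, then pop them one at a time from a queue. Below is a minimal, self-contained sketch of that pattern; the class name, chunk size, and action dimension are invented for illustration and are not part of this commit.

from collections import deque

import torch
from torch import Tensor


class ChunkedPolicyStub:
    """Illustrative stand-in: predict a chunk of actions, then serve them one at a time."""

    def __init__(self, n_action_steps: int = 5, action_dim: int = 14):
        self.n_action_steps = n_action_steps
        self.action_dim = action_dim
        self._action_queue: deque[Tensor] = deque([], maxlen=n_action_steps)

    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
        # Stand-in for the transformer forward pass: (n_action_steps, action_dim).
        return torch.zeros(self.n_action_steps, self.action_dim)

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        # Only run the (expensive) chunk prediction when the queue has been drained.
        if len(self._action_queue) == 0:
            self._action_queue.extend(self.select_actions(batch))
        return self._action_queue.popleft()


policy = ChunkedPolicyStub()
obs = {"observation.state": torch.zeros(1, 14)}
actions = [policy.select_action(obs) for _ in range(7)]  # the queue refills once after 5 steps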
@@ -189,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return self._action_queue.popleft()

     @torch.no_grad()
-    def select_actions(self, batch: dict[str, Tensor]):
+    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """Use the action chunking transformer to generate a sequence of actions."""
         self.eval()
         self._preprocess_batch(batch, add_obs_steps_dim=True)
@@ -211,7 +211,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return action[: self.n_action_steps]

-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> dict:
         # TODO(now): Temporary bridge until we know what to do about the `update` method.
         return self.update(*args, **kwargs)
@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
         # the image index dimension.

-    def update(self, batch, *_):
+    def update(self, batch, *_) -> dict:
         start_time = time.time()
         self._preprocess_batch(batch)
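
`update` is now annotated as returning `dict`. As a hedged sketch of what such a typed training step can look like, here is a placeholder version; the dict keys and the loss computation are illustrative and not taken from this file.

import time

import torch
from torch import Tensor


def update_stub(batch: dict[str, Tensor]) -> dict:
    # Hypothetical training step: compute a loss, take an optimizer step (omitted),
    # and report scalars in a plain dict, matching the `-> dict` annotation on `update`.
    start_time = time.time()
    loss = batch["action"].pow(2).mean()
    return {
        "loss": loss.item(),
        "update_s": time.time() - start_time,
    }


info = update_stub({"action": torch.zeros(8, 100, 14)})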
@@ -277,7 +277,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return info

-    def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
+    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
         images = self.image_normalizer(batch["observation.images.top"])
         if return_loss:  # training time
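
The `dict | Tensor` union reflects that `forward` returns a loss dict during training and an action tensor at inference. Not part of this commit, but for illustration, a union like this can be narrowed for static checkers with `typing.overload`, roughly as sketched below (the class and its body are placeholders):

from typing import Literal, overload

import torch
from torch import Tensor, nn


class PolicyStub(nn.Module):
    @overload
    def forward(self, batch: dict[str, Tensor], return_loss: Literal[True]) -> dict: ...

    @overload
    def forward(self, batch: dict[str, Tensor], return_loss: Literal[False] = ...) -> Tensor: ...

    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
        if return_loss:  # training time
            return {"loss": batch["action"].pow(2).mean()}
        return torch.zeros(batch["observation.state"].shape[0], 14)  # inference time


policy = PolicyStub()
action = policy({"observation.state": torch.zeros(8, 14), "action": torch.zeros(8, 100, 14)})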
@@ -309,7 +309,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
             action, _ = self._forward(batch["observation.state"], images)
             return action

-    def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
+    def _forward(
+        self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         """
         Args:
             robot_state: (B, J) batch of robot joint configurations.
@@ -410,7 +412,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         actions = self.action_head(decoder_out)

-        return actions, [mu, log_sigma_x2]
+        return actions, (mu, log_sigma_x2)

     def save(self, fp):
         torch.save(self.state_dict(), fp)
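
The final hunk makes the runtime value consistent with the new `tuple[Tensor, tuple[Tensor, Tensor]]` annotation: `_forward` now returns a tuple rather than a list alongside the actions. A small sketch of that annotated shape and how a caller unpacks it is below; the dimensions used are placeholders.

import torch
from torch import Tensor


def _forward_stub(robot_state: Tensor) -> tuple[Tensor, tuple[Tensor, Tensor]]:
    batch_size = robot_state.shape[0]
    actions = torch.zeros(batch_size, 100, 14)  # placeholder action chunk
    mu = torch.zeros(batch_size, 32)  # placeholder latent mean
    log_sigma_x2 = torch.zeros(batch_size, 32)  # placeholder 2 * log(sigma)
    # A tuple (not a list) is what the annotation promises.
    return actions, (mu, log_sigma_x2)


# Callers can unpack either way, mirroring `action, _ = self._forward(...)` above:
action, _ = _forward_stub(torch.zeros(8, 14))
actions, (mu, log_sigma_x2) = _forward_stub(torch.zeros(8, 14))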