Add type hints
This commit is contained in:
parent
86365adf9f
commit
62b18a7607
|
@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
if self.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.n_action_steps)
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor], *_):
|
||||
def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
|
||||
"""
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
|
@ -189,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, batch: dict[str, Tensor]):
|
||||
def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Use the action chunking transformer to generate a sequence of actions."""
|
||||
self.eval()
|
||||
self._preprocess_batch(batch, add_obs_steps_dim=True)
|
||||
|
@ -211,7 +211,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
return action[: self.n_action_steps]
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args, **kwargs) -> dict:
|
||||
# TODO(now): Temporary bridge until we know what to do about the `update` method.
|
||||
return self.update(*args, **kwargs)
|
||||
|
||||
|
@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
|
||||
# the image index dimension.
|
||||
|
||||
def update(self, batch, *_):
|
||||
def update(self, batch, *_) -> dict:
|
||||
start_time = time.time()
|
||||
self._preprocess_batch(batch)
|
||||
|
||||
|
@ -277,7 +277,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
return info
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
|
||||
def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
|
||||
images = self.image_normalizer(batch["observation.images.top"])
|
||||
|
||||
if return_loss: # training time
|
||||
|
@ -309,7 +309,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
action, _ = self._forward(batch["observation.state"], images)
|
||||
return action
|
||||
|
||||
def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
|
||||
def _forward(
|
||||
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
robot_state: (B, J) batch of robot joint configurations.
|
||||
|
@ -410,7 +412,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
|
||||
actions = self.action_head(decoder_out)
|
||||
|
||||
return actions, [mu, log_sigma_x2]
|
||||
return actions, (mu, log_sigma_x2)
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
|
Loading…
Reference in New Issue