Add type hints
commit 62b18a7607
parent 86365adf9f
@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         if self.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.n_action_steps)

-    def select_action(self, batch: dict[str, Tensor], *_):
+    def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
         """
         This method wraps `select_actions` in order to return one action at a time for execution in the
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
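The docstring above describes the chunking pattern: act one step at a time from a queue, and call `select_actions` only when the queue runs dry. Below is a minimal standalone sketch of that pattern; the wrapper name and the `(n_action_steps, action_dim)` chunk shape are assumptions for illustration, not this repo's API.

    from collections import deque

    from torch import Tensor

    class ChunkedPolicyWrapper:
        """Illustrative sketch only: one action per call, refilled from a chunk when the queue is empty."""

        def __init__(self, policy, n_action_steps: int):
            self.policy = policy
            self._action_queue: deque[Tensor] = deque([], maxlen=n_action_steps)

        def select_action(self, batch: dict[str, Tensor]) -> Tensor:
            if len(self._action_queue) == 0:
                # Assumes `select_actions` returns a (n_action_steps, action_dim) chunk.
                self._action_queue.extend(self.policy.select_actions(batch))
            return self._action_queue.popleft()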
@@ -189,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return self._action_queue.popleft()

     @torch.no_grad()
-    def select_actions(self, batch: dict[str, Tensor]):
+    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """Use the action chunking transformer to generate a sequence of actions."""
         self.eval()
         self._preprocess_batch(batch, add_obs_steps_dim=True)
@@ -211,7 +211,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return action[: self.n_action_steps]

-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> dict:
         # TODO(now): Temporary bridge until we know what to do about the `update` method.
         return self.update(*args, **kwargs)

@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
         # the image index dimension.

-    def update(self, batch, *_):
+    def update(self, batch, *_) -> dict:
         start_time = time.time()
         self._preprocess_batch(batch)

@@ -277,7 +277,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return info

-    def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
+    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
         images = self.image_normalizer(batch["observation.images.top"])

         if return_loss:  # training time
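The new `dict | Tensor` annotation reflects the two call modes of `forward`. A hypothetical call-site sketch (the `policy` and `batch` objects are assumed here, not constructed):

    # With return_loss=True (training), forward is expected to return a dict of losses/metrics;
    # with the default return_loss=False (inference), it returns the action tensor.
    train_out: dict = policy.forward(batch, return_loss=True)
    action: Tensor = policy.forward(batch)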
@@ -309,7 +309,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
             action, _ = self._forward(batch["observation.state"], images)
             return action

-    def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
+    def _forward(
+        self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         """
         Args:
             robot_state: (B, J) batch of robot joint configurations.
@@ -410,7 +412,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         actions = self.action_head(decoder_out)

-        return actions, [mu, log_sigma_x2]
+        return actions, (mu, log_sigma_x2)

     def save(self, fp):
         torch.save(self.state_dict(), fp)
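Note on the last hunk: returning `(mu, log_sigma_x2)` as a tuple makes the runtime value agree with the new `tuple[Tensor, tuple[Tensor, Tensor]]` annotation on `_forward`. A hypothetical unpacking at a call site, where `policy`, `robot_state`, `image`, and `target_actions` are assumed names:

    pred_actions, (mu, log_sigma_x2) = policy._forward(robot_state, image, actions=target_actions)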