[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 035e95a41b
commit 78c05cf0be
@@ -68,6 +68,7 @@ class DOT(nn.Module):
     │                       Outputs                        │
     └──────────────────────────────────────────────────────┘
     """
 
     def __init__(self, config: DOTConfig):
         super().__init__()
         self.config = config
@@ -196,18 +197,22 @@ class DOT(nn.Module):
             Tensor: A tensor of shape (B, horizon, action_dim) containing predicted future actions.
         """
         # Project image/state/env_state inputs to the model dimension and concatenate along the time axis.
         inputs_projections = self._process_inputs(batch)  # (B, T, D)
         batch_size = inputs_projections.shape[0]
 
         # Add learnable positional embeddings to each projected input token.
         inputs_projections += self.inputs_pos_emb.expand(batch_size, -1, -1)
 
         # Prepend a trainable prefix token to the input sequence
-        inputs_projections = torch.cat([self.prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1)  # (B, T+1, D)
+        inputs_projections = torch.cat(
+            [self.prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1
+        )  # (B, T+1, D)
 
         # Use different positional encodings and masks for training vs. inference.
         if self.training:
-            decoder_out = self.decoder(self.decoder_pos.expand(batch_size, -1, -1), inputs_projections, self.mask)
+            decoder_out = self.decoder(
+                self.decoder_pos.expand(batch_size, -1, -1), inputs_projections, self.mask
+            )
         else:
             decoder_out = self.decoder(self.decoder_pos_inf.expand(batch_size, -1, -1), inputs_projections)
         return self.action_head(decoder_out)
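For context, a minimal sketch (not part of this commit) of the prefix-token pattern that the reformatted torch.cat call implements: positional embeddings are added to the projected inputs, then a trainable prefix token is prepended along the time axis. The sizes below are illustrative assumptions, not values from DOTConfig.

import torch
import torch.nn as nn

batch_size, seq_len, d_model = 4, 10, 128                          # illustrative sizes, not from DOTConfig

inputs_pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))    # learnable positional embeddings
prefix_input = nn.Parameter(torch.zeros(1, 1, d_model))            # trainable prefix token

inputs_projections = torch.randn(batch_size, seq_len, d_model)     # stand-in for _process_inputs(batch)
inputs_projections = inputs_projections + inputs_pos_emb.expand(batch_size, -1, -1)

# Prepend the prefix token along the time axis: (B, T, D) -> (B, T+1, D)
inputs_projections = torch.cat(
    [prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1
)
print(inputs_projections.shape)  # torch.Size([4, 11, 128])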
@@ -307,7 +312,7 @@ class DOTPolicy(PreTrainedPolicy):
         return self.model.parameters()
 
     def _update_observation_buffers(self, buffer_name: str, observation: Tensor) -> Tensor:
         # Maintain a rolling buffer of lookback_obs_steps + 1;
         # shift left and append new observation each step
         if buffer_name not in self._input_buffers:
             self._input_buffers[buffer_name] = observation.unsqueeze(1).repeat(
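A minimal sketch (not part of this commit) of the rolling-buffer behaviour described by the comments in _update_observation_buffers above: on first use the buffer is filled with the initial observation, afterwards it is shifted left by one step and the newest observation is appended. The helper name, buffer length, and shapes are assumptions for illustration.

import torch

def update_buffer(buffers: dict, name: str, obs: torch.Tensor, lookback_obs_steps: int = 3) -> torch.Tensor:
    if name not in buffers:
        # First call: fill the whole buffer with the initial observation.
        buffers[name] = obs.unsqueeze(1).repeat(
            1, lookback_obs_steps + 1, *([1] * (obs.dim() - 1))
        )
    else:
        # Shift left by one step and append the newest observation at the end.
        buffers[name] = torch.cat([buffers[name][:, 1:], obs.unsqueeze(1)], dim=1)
    return buffers[name]

buffers = {}
for step in range(5):
    obs = torch.full((2, 7), float(step))   # (B, obs_dim)
    buf = update_buffer(buffers, "observation.state", obs)
print(buf.shape)     # torch.Size([2, 4, 7]): lookback_obs_steps + 1 frames
print(buf[0, :, 0])  # tensor([1., 2., 3., 4.]), oldest to newest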
@@ -472,6 +477,7 @@ class LoRAConv2d(nn.Module):
         base_conv (nn.Conv2d): The original convolutional layer to be adapted.
         rank (int): The rank of the low-rank approximation (default: 4).
     """
 
     def __init__(self, base_conv: nn.Conv2d, rank: int = 4):
         super().__init__()
         self.base_conv = base_conv
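A minimal sketch (not part of this commit) of low-rank adaptation for a Conv2d layer in the spirit of the LoRAConv2d docstring above: the base convolution is frozen and a small rank-limited update is learned alongside it. The down/up convolution decomposition and the class name are assumptions about one common way to realize this, not the repository's implementation.

import torch
import torch.nn as nn

class LoRAConv2dSketch(nn.Module):
    def __init__(self, base_conv: nn.Conv2d, rank: int = 4):
        super().__init__()
        self.base_conv = base_conv
        for p in self.base_conv.parameters():
            p.requires_grad = False  # freeze the original weights
        # Low-rank update: project down to `rank` channels, then back up.
        self.lora_down = nn.Conv2d(
            base_conv.in_channels, rank,
            kernel_size=base_conv.kernel_size,
            stride=base_conv.stride,
            padding=base_conv.padding,
            bias=False,
        )
        self.lora_up = nn.Conv2d(rank, base_conv.out_channels, kernel_size=1, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # the update starts at zero, so initial outputs match base_conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_conv(x) + self.lora_up(self.lora_down(x))

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
adapted = LoRAConv2dSketch(conv, rank=4)
print(adapted(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])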