[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-04-04 10:04:32 +00:00
parent 035e95a41b
commit 78c05cf0be
1 changed files with 10 additions and 4 deletions

View File

@ -68,6 +68,7 @@ class DOT(nn.Module):
Outputs
"""
def __init__(self, config: DOTConfig):
super().__init__()
self.config = config
@ -203,11 +204,15 @@ class DOT(nn.Module):
inputs_projections += self.inputs_pos_emb.expand(batch_size, -1, -1)
# Prepend a trainable prefix token to the input sequence
inputs_projections = torch.cat([self.prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1) # (B, T+1, D)
inputs_projections = torch.cat(
[self.prefix_input.expand(batch_size, -1, -1), inputs_projections], dim=1
) # (B, T+1, D)
# Use different positional encodings and masks for training vs. inference.
if self.training:
decoder_out = self.decoder(self.decoder_pos.expand(batch_size, -1, -1), inputs_projections, self.mask)
decoder_out = self.decoder(
self.decoder_pos.expand(batch_size, -1, -1), inputs_projections, self.mask
)
else:
decoder_out = self.decoder(self.decoder_pos_inf.expand(batch_size, -1, -1), inputs_projections)
return self.action_head(decoder_out)
@ -472,6 +477,7 @@ class LoRAConv2d(nn.Module):
base_conv (nn.Conv2d): The original convolutional layer to be adapted.
rank (int): The rank of the low-rank approximation (default: 4).
"""
def __init__(self, base_conv: nn.Conv2d, rank: int = 4):
super().__init__()
self.base_conv = base_conv