revert some formatting changes

This commit is contained in:
Alexander Soare 2024-04-17 11:01:01 +01:00
parent c50a13ab31
commit 63e5ec6483
1 changed files with 3 additions and 12 deletions

View File

@ -20,9 +20,7 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import (
ActionChunkingTransformerConfig,
)
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
class ActionChunkingTransformerPolicy(nn.Module):
@ -99,11 +97,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[
False,
False,
cfg.replace_final_stride_with_dilation,
],
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
norm_layer=FrozenBatchNorm2d,
)
@ -445,10 +439,7 @@ class _TransformerDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
if self.norm is not None:
x = self.norm(x)