revert some formatting changes

This commit is contained in:
Alexander Soare 2024-04-17 11:01:01 +01:00
parent c50a13ab31
commit 63e5ec6483
1 changed files with 3 additions and 12 deletions

View File

@ -20,9 +20,7 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ( from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
ActionChunkingTransformerConfig,
)
class ActionChunkingTransformerPolicy(nn.Module): class ActionChunkingTransformerPolicy(nn.Module):
@ -99,11 +97,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
) )
backbone_model = getattr(torchvision.models, cfg.vision_backbone)( backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[ replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
False,
False,
cfg.replace_final_stride_with_dilation,
],
pretrained=cfg.use_pretrained_backbone, pretrained=cfg.use_pretrained_backbone,
norm_layer=FrozenBatchNorm2d, norm_layer=FrozenBatchNorm2d,
) )
@ -445,10 +439,7 @@ class _TransformerDecoder(nn.Module):
) -> Tensor: ) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer( x = layer(
x, x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
) )
if self.norm is not None: if self.norm is not None:
x = self.norm(x) x = self.norm(x)