From 63e5ec64837ae765289d08b3766300105a31198a Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 17 Apr 2024 11:01:01 +0100 Subject: [PATCH] revert some formatting changes --- lerobot/common/policies/act/modeling_act.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index af8566c7..6b9b4e0a 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,9 +20,7 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.policies.act.configuration_act import ( - ActionChunkingTransformerConfig, -) +from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig class ActionChunkingTransformerPolicy(nn.Module): @@ -99,11 +97,7 @@ class ActionChunkingTransformerPolicy(nn.Module): mean=cfg.image_normalization_mean, std=cfg.image_normalization_std ) backbone_model = getattr(torchvision.models, cfg.vision_backbone)( - replace_stride_with_dilation=[ - False, - False, - cfg.replace_final_stride_with_dilation, - ], + replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], pretrained=cfg.use_pretrained_backbone, norm_layer=FrozenBatchNorm2d, ) @@ -445,10 +439,7 @@ class _TransformerDecoder(nn.Module): ) -> Tensor: for layer in self.layers: x = layer( - x, - encoder_out, - decoder_pos_embed=decoder_pos_embed, - encoder_pos_embed=encoder_pos_embed, + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed ) if self.norm is not None: x = self.norm(x)