revert some formatting changes
This commit is contained in:
parent
c50a13ab31
commit
63e5ec6483
|
@ -20,9 +20,7 @@ from torch import Tensor, nn
|
|||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import (
|
||||
ActionChunkingTransformerConfig,
|
||||
)
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
@ -99,11 +97,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[
|
||||
False,
|
||||
False,
|
||||
cfg.replace_final_stride_with_dilation,
|
||||
],
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
pretrained=cfg.use_pretrained_backbone,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
|
@ -445,10 +439,7 @@ class _TransformerDecoder(nn.Module):
|
|||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x,
|
||||
encoder_out,
|
||||
decoder_pos_embed=decoder_pos_embed,
|
||||
encoder_pos_embed=encoder_pos_embed,
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
|
Loading…
Reference in New Issue