revert some formatting changes
This commit is contained in:
parent
c50a13ab31
commit
63e5ec6483
|
@ -20,9 +20,7 @@ from torch import Tensor, nn
|
||||||
from torchvision.models._utils import IntermediateLayerGetter
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import (
|
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||||
ActionChunkingTransformerConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerPolicy(nn.Module):
|
class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
|
@ -99,11 +97,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||||
)
|
)
|
||||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||||
replace_stride_with_dilation=[
|
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||||
False,
|
|
||||||
False,
|
|
||||||
cfg.replace_final_stride_with_dilation,
|
|
||||||
],
|
|
||||||
pretrained=cfg.use_pretrained_backbone,
|
pretrained=cfg.use_pretrained_backbone,
|
||||||
norm_layer=FrozenBatchNorm2d,
|
norm_layer=FrozenBatchNorm2d,
|
||||||
)
|
)
|
||||||
|
@ -445,10 +439,7 @@ class _TransformerDecoder(nn.Module):
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(
|
x = layer(
|
||||||
x,
|
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||||
encoder_out,
|
|
||||||
decoder_pos_embed=decoder_pos_embed,
|
|
||||||
encoder_pos_embed=encoder_pos_embed,
|
|
||||||
)
|
)
|
||||||
if self.norm is not None:
|
if self.norm is not None:
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
Loading…
Reference in New Issue