fix(act): n_vae_encoder_layers config parameter wasn't being used (#400)
This commit is contained in:
parent
c0da806232
commit
b2896d38f5
|
@ -296,7 +296,7 @@ class ACT(nn.Module):
|
||||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
self.vae_encoder = ACTEncoder(config)
|
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
||||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||||
# Projection layer for joint-space configuration to hidden dimension.
|
# Projection layer for joint-space configuration to hidden dimension.
|
||||||
if self.use_robot_state:
|
if self.use_robot_state:
|
||||||
|
@ -521,9 +521,11 @@ class ACT(nn.Module):
|
||||||
class ACTEncoder(nn.Module):
|
class ACTEncoder(nn.Module):
|
||||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
||||||
|
|
||||||
def __init__(self, config: ACTConfig):
|
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
|
self.is_vae_encoder = is_vae_encoder
|
||||||
|
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
|
||||||
|
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
|
||||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
Loading…
Reference in New Issue