Improve documentation on VAE encoder inputs (#215)
This commit is contained in:
parent
0b51a335bc
commit
57fb5fe8a6
|
@ -198,7 +198,7 @@ class ACT(nn.Module):
|
||||||
def __init__(self, config: ACTConfig):
|
def __init__(self, config: ACTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
self.vae_encoder = ACTEncoder(config)
|
self.vae_encoder = ACTEncoder(config)
|
||||||
|
@ -214,7 +214,7 @@ class ACT(nn.Module):
|
||||||
self.latent_dim = config.latent_dim
|
self.latent_dim = config.latent_dim
|
||||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
|
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
|
||||||
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
|
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||||
# dimension.
|
# dimension.
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"vae_encoder_pos_enc",
|
"vae_encoder_pos_enc",
|
||||||
|
|
Loading…
Reference in New Issue