revision
This commit is contained in:
parent
0a721f3d94
commit
86365adf9f
|
@ -32,12 +32,14 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||||
elif cfg.env.name == "aloha":
|
elif cfg.env.name == "aloha":
|
||||||
import gym_aloha # noqa: F401
|
import gym_aloha # noqa: F401
|
||||||
|
|
||||||
kwargs["task"] = cfg.env.task
|
if cfg.env.task == "sim_transfer_cube":
|
||||||
|
env_name = "gym_aloha/AlohaTransferCube-v0"
|
||||||
|
elif cfg.env.task == "sim_insertion":
|
||||||
|
env_name = "gym_aloha/AlohaInsertion-v0"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"`{cfg.env.task}` has no environment implementation.")
|
||||||
|
|
||||||
env_fn = lambda: gym.make( # noqa: E731
|
env_fn = lambda: gym.make(env_name, **kwargs) # noqa: E731
|
||||||
"gym_aloha/AlohaTransferCube-v0",
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.env.name)
|
raise ValueError(cfg.env.name)
|
||||||
|
|
||||||
|
|
|
@ -337,18 +337,21 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
|
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
|
||||||
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
|
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
|
||||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
||||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
|
||||||
# Prepare fixed positional embedding.
|
# Prepare fixed positional embedding.
|
||||||
|
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||||
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
|
|
||||||
|
# Forward pass through VAE encoder.
|
||||||
cls_token_out = self.vae_encoder(
|
cls_token_out = self.vae_encoder(
|
||||||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
||||||
)[0] # (B, D)
|
)[0] # select the class token, with shape (B, D)
|
||||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||||
|
|
||||||
|
# Sample the latent with the reparameterization trick.
|
||||||
mu = latent_pdf_params[:, : self.latent_dim]
|
mu = latent_pdf_params[:, : self.latent_dim]
|
||||||
# This is 2log(sigma). Done this way to match the original implementation.
|
# This is 2log(sigma). Done this way to match the original implementation.
|
||||||
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
||||||
# Use reparameterization trick to sample from the latent's PDF.
|
|
||||||
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
||||||
else:
|
else:
|
||||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||||
|
@ -469,7 +472,7 @@ class _TransformerEncoderLayer(nn.Module):
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.norm1(x)
|
x = self.norm1(x)
|
||||||
q = k = x if pos_embed is None else x + pos_embed
|
q = k = x if pos_embed is None else x + pos_embed
|
||||||
x = self.self_attn(q, k, value=x)[0]
|
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||||
x = skip + self.dropout1(x)
|
x = skip + self.dropout1(x)
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
skip = x
|
skip = x
|
||||||
|
@ -563,7 +566,7 @@ class _TransformerDecoderLayer(nn.Module):
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.norm1(x)
|
x = self.norm1(x)
|
||||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||||
x = self.self_attn(q, k, value=x)[0]
|
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||||
x = skip + self.dropout1(x)
|
x = skip + self.dropout1(x)
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
skip = x
|
skip = x
|
||||||
|
@ -575,7 +578,7 @@ class _TransformerDecoderLayer(nn.Module):
|
||||||
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
|
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
|
||||||
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
|
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
|
||||||
value=encoder_out,
|
value=encoder_out,
|
||||||
)[0]
|
)[0] # select just the output, not the attention weights
|
||||||
x = skip + self.dropout2(x)
|
x = skip + self.dropout2(x)
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
skip = x
|
skip = x
|
||||||
|
@ -634,7 +637,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
|
||||||
Returns:
|
Returns:
|
||||||
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
|
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
|
||||||
"""
|
"""
|
||||||
not_mask = torch.ones_like(x[0, [0]]) # (1, H, W)
|
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
||||||
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
||||||
# they would be range(0, H) and range(0, W). Keeping it as is to match the original code.
|
# they would be range(0, H) and range(0, W). Keeping it as is to match the original code.
|
||||||
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
||||||
|
|
|
@ -64,4 +64,4 @@ policy:
|
||||||
delta_timestamps:
|
delta_timestamps:
|
||||||
observation.images.top: [0.0]
|
observation.images.top: [0.0]
|
||||||
observation.state: [0.0]
|
observation.state: [0.0]
|
||||||
action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98]
|
action: "[i / ${fps} for i in range(${horizon})]"
|
||||||
|
|
|
@ -152,7 +152,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
policy = make_policy(cfg)
|
policy = make_policy(cfg)
|
||||||
policy.save("act.pt")
|
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
@ -198,7 +197,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
is_offline = True
|
is_offline = True
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=0,
|
num_workers=4,
|
||||||
batch_size=cfg.policy.batch_size,
|
batch_size=cfg.policy.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=cfg.device != "cpu",
|
||||||
|
|
Loading…
Reference in New Issue