diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 749bb533..bcbdb95d 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -32,12 +32,14 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: elif cfg.env.name == "aloha": import gym_aloha # noqa: F401 - kwargs["task"] = cfg.env.task + if cfg.env.task == "sim_transfer_cube": + env_name = "gym_aloha/AlohaTransferCube-v0" + elif cfg.env.task == "sim_insertion": + env_name = "gym_aloha/AlohaInsertion-v0" + else: + raise ValueError(f"`{cfg.env.task}` has no environment implementation.") - env_fn = lambda: gym.make( # noqa: E731 - "gym_aloha/AlohaTransferCube-v0", - **kwargs, - ) + env_fn = lambda: gym.make(env_name, **kwargs) # noqa: E731 else: raise ValueError(cfg.env.name) diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 834dd9b2..7fb03576 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -337,18 +337,21 @@ class ActionChunkingTransformerPolicy(nn.Module): robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D) action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D) vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D) - # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + # Prepare fixed positional embedding. + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) - # Forward pass through VAE encoder and sample the latent with the reparameterization trick. + + # Forward pass through VAE encoder. cls_token_out = self.vae_encoder( vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) - )[0] # (B, D) + )[0] # select the class token, with shape (B, D) latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) + + # Sample the latent with the reparameterization trick. mu = latent_pdf_params[:, : self.latent_dim] # This is 2log(sigma). Done this way to match the original implementation. log_sigma_x2 = latent_pdf_params[:, self.latent_dim :] - # Use reparameterization trick to sample from the latent's PDF. latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) else: # When not using the VAE encoder, we set the latent to be all zeros. @@ -469,7 +472,7 @@ class _TransformerEncoderLayer(nn.Module): if self.normalize_before: x = self.norm1(x) q = k = x if pos_embed is None else x + pos_embed - x = self.self_attn(q, k, value=x)[0] + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = skip + self.dropout1(x) if self.normalize_before: skip = x @@ -563,7 +566,7 @@ class _TransformerDecoderLayer(nn.Module): if self.normalize_before: x = self.norm1(x) q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) - x = self.self_attn(q, k, value=x)[0] + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = skip + self.dropout1(x) if self.normalize_before: skip = x @@ -575,7 +578,7 @@ class _TransformerDecoderLayer(nn.Module): query=self.maybe_add_pos_embed(x, decoder_pos_embed), key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), value=encoder_out, - )[0] + )[0] # select just the output, not the attention weights x = skip + self.dropout2(x) if self.normalize_before: skip = x @@ -634,7 +637,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module): Returns: A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. """ - not_mask = torch.ones_like(x[0, [0]]) # (1, H, W) + not_mask = torch.ones_like(x[0, :1]) # (1, H, W) # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations # they would be range(0, H) and range(0, W). Keeping it at as to match the original code. y_range = not_mask.cumsum(1, dtype=torch.float32) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index cd34d115..79729a02 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -64,4 +64,4 @@ policy: delta_timestamps: observation.images.top: [0.0] observation.state: [0.0] - action: [0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, 0.6, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, 1.40, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, 1.8, 1.82, 1.84, 1.86, 1.88, 1.90, 1.92, 1.94, 1.96, 1.98] + action: "[i / ${fps} for i in range(${horizon})]" diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index d49dfff8..caaf5182 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -152,7 +152,6 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info("make_policy") policy = make_policy(cfg) - policy.save("act.pt") num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) @@ -198,7 +197,7 @@ def train(cfg: dict, out_dir=None, job_name=None): is_offline = True dataloader = torch.utils.data.DataLoader( dataset, - num_workers=0, + num_workers=4, batch_size=cfg.policy.batch_size, shuffle=True, pin_memory=cfg.device != "cpu",