diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 95f443da..520390e4 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -66,8 +66,12 @@ class ACTConfig: documentation in the policy class). latent_dim: The VAE's latent dimension. n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. - use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given - environment step. + temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling + actions for a given time step over multiple policy invocations. Updates are calculated as: + x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different + parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our + formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we + require `n_action_steps == 1` (since we need to query the policy every step anyway). dropout: Dropout to use in the transformer layers (see code for details). kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. @@ -122,7 +126,7 @@ class ACTConfig: n_vae_encoder_layers: int = 4 # Inference. - use_temporal_aggregation: bool = False + temporal_ensemble_momentum: float | None = None # Training and loss computation. dropout: float = 0.1 @@ -134,8 +138,11 @@ class ACTConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) - if self.use_temporal_aggregation: - raise NotImplementedError("Temporal aggregation is not yet implemented.") + if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1: + raise NotImplementedError( + "`n_action_steps` must be 1 when using temporal ensembling. This is " + "because the policy needs to be queried every step to compute the ensembled action." + ) if self.n_action_steps > self.chunk_size: raise ValueError( f"The chunk size is the upper bound for the number of action steps per model invocation. Got " diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index e85a3736..dac1b889 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -61,7 +61,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): super().__init__() if config is None: config = ACTConfig() - self.config = config + self.config: ACTConfig = config self.normalize_inputs = Normalize( config.input_shapes, config.input_normalization_modes, dataset_stats ) @@ -75,7 +75,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): def reset(self): """This should be called whenever the environment is reset.""" - if self.config.n_action_steps is not None: + if self.config.temporal_ensemble_momentum is not None: + self._ensembled_actions = None + else: self._action_queue = deque([], maxlen=self.config.n_action_steps) @torch.no_grad @@ -94,14 +96,33 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch = self.normalize_inputs(batch) self._stack_images(batch) + # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return + # the first action. + if self.config.temporal_ensemble_momentum is not None: + actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) + if self._ensembled_actions is None: + self._ensembled_actions = actions.clone() + else: + # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute + # the EMA update for those entries. + alpha = self.config.temporal_ensemble_momentum + self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1] + # The last action, which has no prior moving average, needs to get concatenated onto the end. + self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1) + # "Consume" the first action. + action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:] + return action + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. if len(self._action_queue) == 0: - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - actions = self.model(batch)[0][: self.config.n_action_steps] + actions = self.model(batch)[0][:, : self.config.n_action_steps] # TODO(rcadene): make _forward return output dictionary? actions = self.unnormalize_outputs({"action": actions})["action"] + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 15efcce8..f09e6a12 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -73,7 +73,7 @@ policy: n_vae_encoder_layers: 4 # Inference. - use_temporal_aggregation: false + temporal_ensemble_momentum: null # Training and loss computation. dropout: 0.1