From 625f0557ef9d2fc6b57604dafb99820d89d1d02e Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Fri, 17 May 2024 14:57:49 +0100
Subject: [PATCH] Act temporal ensembling (#186)

---
 .../common/policies/act/configuration_act.py | 17 +++++++----
 lerobot/common/policies/act/modeling_act.py  | 28 +++++++++++++++++--
 lerobot/configs/policy/act.yaml              |  2 +-
 3 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index be444b06..cc072083 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -66,8 +66,12 @@ class ACTConfig:
             documentation in the policy class).
         latent_dim: The VAE's latent dimension.
         n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
-        use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
-            environment step.
+        temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
+            actions for a given time step over multiple policy invocations. Updates are calculated as:
+            x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and the original ACT code describe a different
+            parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
+            formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
+            require `n_action_steps == 1` (since we need to query the policy every step anyway).
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@@ -122,7 +126,7 @@ class ACTConfig:
     n_vae_encoder_layers: int = 4
 
     # Inference.
-    use_temporal_aggregation: bool = False
+    temporal_ensemble_momentum: float | None = None
 
     # Training and loss computation.
     dropout: float = 0.1
@@ -134,8 +138,11 @@ class ACTConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
-        if self.use_temporal_aggregation:
-            raise NotImplementedError("Temporal aggregation is not yet implemented.")
+        if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
+            raise NotImplementedError(
+                "`n_action_steps` must be 1 when using temporal ensembling. This is "
+                "because the policy needs to be queried every step to compute the ensembled action."
+            )
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 3aab03cf..72ebdd7a 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -61,7 +61,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         super().__init__()
         if config is None:
             config = ACTConfig()
-        self.config = config
+        self.config: ACTConfig = config
 
         self.normalize_inputs = Normalize(
             config.input_shapes, config.input_normalization_modes, dataset_stats
@@ -81,7 +81,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
 
     def reset(self):
         """This should be called whenever the environment is reset."""
-        if self.config.n_action_steps is not None:
+        if self.config.temporal_ensemble_momentum is not None:
+            self._ensembled_actions = None
+        else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
     @torch.no_grad
@@ -97,6 +99,28 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
 
+        # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
+        # the first action.
+        if self.config.temporal_ensemble_momentum is not None:
+            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+            if self._ensembled_actions is None:
+                # Initializes `self._ensembled_actions` to the sequence of actions predicted during the first
+                # time step of the episode.
+                self._ensembled_actions = actions.clone()
+            else:
+                # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
+                # the EMA update for those entries.
+                alpha = self.config.temporal_ensemble_momentum
+                self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
+                # The last action, which has no prior moving average, needs to get concatenated onto the end.
+                self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
+            # "Consume" the first action.
+            action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
+            return action
+
+        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
+        # querying the policy.
         if len(self._action_queue) == 0:
             actions = self.model(batch)[0][:, : self.config.n_action_steps]
 
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index 15efcce8..f09e6a12 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -73,7 +73,7 @@ policy:
   n_vae_encoder_layers: 4
 
   # Inference.
-  use_temporal_aggregation: false
+  temporal_ensemble_momentum: null
 
   # Training and loss computation.
   dropout: 0.1