Act temporal ensembling (#186)

This commit is contained in:
Alexander Soare 2024-05-17 14:57:49 +01:00 committed by GitHub
parent 4d7d41cdee
commit 625f0557ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 8 deletions

View File

@ -66,8 +66,12 @@ class ACTConfig:
documentation in the policy class).
latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
environment step.
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
actions for a given time step over multiple policy invocations. Updates are calculated as:
x = αx + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
parameter here: they refer to a weighting scheme wᵢ = exp(-mi) and set m = 0.01. With our
formulation, this is equivalent to α = exp(-0.01) 0.99. When this parameter is provided, we
require `n_action_steps == 1` (since we need to query the policy every step anyway).
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@ -122,7 +126,7 @@ class ACTConfig:
n_vae_encoder_layers: int = 4
# Inference.
use_temporal_aggregation: bool = False
temporal_ensemble_momentum: float | None = None
# Training and loss computation.
dropout: float = 0.1
@ -134,8 +138,11 @@ class ACTConfig:
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.use_temporal_aggregation:
raise NotImplementedError("Temporal aggregation is not yet implemented.")
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action."
)
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "

View File

@ -61,7 +61,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
super().__init__()
if config is None:
config = ACTConfig()
self.config = config
self.config: ACTConfig = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
@ -81,7 +81,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.n_action_steps is not None:
if self.config.temporal_ensemble_momentum is not None:
self._ensembled_actions = None
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
@ -97,6 +99,28 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
# the first action.
if self.config.temporal_ensemble_momentum is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self._ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self._ensembled_actions = actions.clone()
else:
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the EMA update for those entries.
alpha = self.config.temporal_ensemble_momentum
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
# The last action, which has no prior moving average, needs to get concatenated onto the end.
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
# "Consume" the first action.
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.model(batch)[0][:, : self.config.n_action_steps]

View File

@ -73,7 +73,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
use_temporal_aggregation: false
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1