From c0101f094805b03bb679731516eb5bcbf053b678 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Tue, 16 Jul 2024 10:27:21 +0100
Subject: [PATCH] Fix ACT temporal ensembling (#319)

---
 .../advanced/1_train_act_pusht/act_pusht.yaml |   2 +-
 .../common/policies/act/configuration_act.py  |  15 ++-
 lerobot/common/policies/act/modeling_act.py   | 118 +++++++++++++++---
 lerobot/configs/policy/act.yaml               |   2 +-
 lerobot/configs/policy/act_real.yaml          |   2 +-
 lerobot/configs/policy/act_real_no_state.yaml |   2 +-
 tests/test_policies.py                        |  63 +++++++++-
 7 files changed, 173 insertions(+), 31 deletions(-)

diff --git a/examples/advanced/1_train_act_pusht/act_pusht.yaml b/examples/advanced/1_train_act_pusht/act_pusht.yaml
index 38e542fb..4963e11c 100644
--- a/examples/advanced/1_train_act_pusht/act_pusht.yaml
+++ b/examples/advanced/1_train_act_pusht/act_pusht.yaml
@@ -80,7 +80,7 @@ policy:
   n_vae_encoder_layers: 4

   # Inference.
-  temporal_ensemble_momentum: null
+  temporal_ensemble_coeff: null

   # Training and loss computation.
   dropout: 0.1
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 92a52eac..a86c359c 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -76,12 +76,10 @@ class ACTConfig:
             documentation in the policy class).
         latent_dim: The VAE's latent dimension.
         n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
-        temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
-            actions for a given time step over multiple policy invocations. Updates are calculated as:
-            x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
-            parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
-            formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
-            require `n_action_steps == 1` (since we need to query the policy every step anyway).
+        temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
+            ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
+            1 when using this feature, as inference needs to happen at every step to form an ensemble. For
+            more information on how ensembling works, please see `ACTTemporalEnsembler`.
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@@ -139,7 +137,8 @@ class ACTConfig:
     n_vae_encoder_layers: int = 4

     # Inference.
-    temporal_ensemble_momentum: float | None = None
+    # Note: the value used in ACT when temporal ensembling is enabled is 0.01.
+    temporal_ensemble_coeff: float | None = None

     # Training and loss computation.
     dropout: float = 0.1
@@ -151,7 +150,7 @@
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
-        if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
+        if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
             raise NotImplementedError(
                 "`n_action_steps` must be 1 when using temporal ensembling. This is "
                 "because the policy needs to be queried every step to compute the ensembled action."
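Aside (not part of the patch): a minimal sketch of how the new exponential weighting behaves for a few values of `temporal_ensemble_coeff`, using only the scheme documented above (wᵢ = exp(-coeff · i) with i = 0 the oldest action, normalized to sum to 1). The chunk size and printing are illustrative only.

```python
# Minimal illustration of the weighting scheme behind `temporal_ensemble_coeff` (not part of the patch).
import torch

chunk_size = 5  # small, for readability
for coeff in (0.01, 0.0, -0.01):
    weights = torch.exp(-coeff * torch.arange(chunk_size))  # w_i = exp(-coeff * i), i=0 is the oldest action
    weights = weights / weights.sum()  # normalize to sum to 1
    print(f"coeff={coeff:+.2f} -> {[round(w, 3) for w in weights.tolist()]}")

# coeff=+0.01 (the ACT default when ensembling is enabled) weighs older actions slightly more,
# coeff=0.0 weighs all actions uniformly, and coeff=-0.01 weighs newer actions slightly more.
```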
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 0a236100..c072c31e 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -77,12 +77,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):

         self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

+        if config.temporal_ensemble_coeff is not None:
+            self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
+
         self.reset()

     def reset(self):
         """This should be called whenever the environment is reset."""
-        if self.config.temporal_ensemble_momentum is not None:
-            self._ensembled_actions = None
+        if self.config.temporal_ensemble_coeff is not None:
+            self.temporal_ensembler.reset()
         else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)

@@ -100,24 +103,12 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         if len(self.expected_image_keys) > 0:
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

-        # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
-        # the first action.
-        if self.config.temporal_ensemble_momentum is not None:
+        # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
+        # we are ensembling over.
+        if self.config.temporal_ensemble_coeff is not None:
             actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
             actions = self.unnormalize_outputs({"action": actions})["action"]
-            if self._ensembled_actions is None:
-                # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
-                # time step of the episode.
-                self._ensembled_actions = actions.clone()
-            else:
-                # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
-                # the EMA update for those entries.
-                alpha = self.config.temporal_ensemble_momentum
-                self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
-                # The last action, which has no prior moving average, needs to get concatenated onto the end.
-                self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
-            # "Consume" the first action.
-            action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
+            action = self.temporal_ensembler.update(actions)
             return action

         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
@@ -162,6 +153,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):

         return loss_dict

+
+class ACTTemporalEnsembler:
+    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
+        """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
+
+        The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
+        They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
+        coefficient works:
+        - Setting it to 0 uniformly weighs all actions.
+        - Setting it positive gives more weight to older actions.
+        - Setting it negative gives more weight to newer actions.
+        NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01.
+        This results in older actions being weighed more highly than newer actions (the experiments documented
+        in https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
+        detrimental: doing so aggressively may diminish the benefits of action chunking).
+
+        Here we use an online method for computing the average rather than caching a history of actions in
+        order to compute the average offline. For a simple 1D sequence it looks something like:
+
+        ```
+        import torch
+
+        seq = torch.linspace(8, 8.5, 100)
+        print(seq)
+
+        m = 0.01
+        exp_weights = torch.exp(-m * torch.arange(len(seq)))
+        print(exp_weights)
+
+        # Calculate offline
+        avg = (exp_weights * seq).sum() / exp_weights.sum()
+        print("offline", avg)
+
+        # Calculate online
+        for i, item in enumerate(seq):
+            if i == 0:
+                avg = item
+                continue
+            avg *= exp_weights[:i].sum()
+            avg += item * exp_weights[i]
+            avg /= exp_weights[:i+1].sum()
+        print("online", avg)
+        ```
+        """
+        self.chunk_size = chunk_size
+        self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
+        self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
+        self.reset()
+
+    def reset(self):
+        """Resets the online computation variables."""
+        self.ensembled_actions = None
+        # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
+        self.ensembled_actions_count = None
+
+    def update(self, actions: Tensor) -> Tensor:
+        """
+        Takes a (batch, chunk_size, action_dim) sequence of actions, updates the temporal ensemble for all
+        time steps, and pops/returns the next batch of actions in the sequence.
+        """
+        self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
+        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
+        if self.ensembled_actions is None:
+            # Initializes `self.ensembled_actions` to the sequence of actions predicted during the first
+            # time step of the episode.
+            self.ensembled_actions = actions.clone()
+            # Note: The last dimension is unsqueezed to make sure we can broadcast properly for tensor
+            # operations later.
+            self.ensembled_actions_count = torch.ones(
+                (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
+            )
+        else:
+            # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
+            # the online update for those entries.
+            self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
+            self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
+            self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
+            self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
+            # The last action, which has no prior online average, needs to get concatenated onto the end.
+            self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
+            self.ensembled_actions_count = torch.cat(
+                [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
+            )
+        # "Consume" the first action.
+        action, self.ensembled_actions, self.ensembled_actions_count = (
+            self.ensembled_actions[:, 0],
+            self.ensembled_actions[:, 1:],
+            self.ensembled_actions_count[1:],
+        )
+        return action
+

 class ACT(nn.Module):
     """Action Chunking Transformer: The underlying neural network for ACTPolicy.
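Aside (not part of the patch): a hedged sketch of driving `ACTTemporalEnsembler` directly in a rollout loop, mirroring what `ACTPolicy.select_action` now does internally. The model forward pass is mocked with random chunks and the sizes are arbitrary; in the real policy the chunk comes from `self.model(batch)[0]` followed by unnormalization.

```python
# Usage sketch for ACTTemporalEnsembler (not part of the patch); the "model" is mocked with random chunks.
import torch

from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler

batch_size, chunk_size, action_dim = 2, 100, 6  # arbitrary illustrative sizes
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff=0.01, chunk_size=chunk_size)

ensembler.reset()  # call at the start of every episode, as ACTPolicy.reset() does
for _ in range(10):  # one iteration per environment step
    action_chunk = torch.rand(batch_size, chunk_size, action_dim)  # stand-in for a policy forward pass
    action = ensembler.update(action_chunk)  # ensembled (batch_size, action_dim) action for this step
    # ... step the environment with `action` ...
```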
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index ea2c5b75..28883936 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -75,7 +75,7 @@ policy:
   n_vae_encoder_layers: 4

   # Inference.
-  temporal_ensemble_momentum: null
+  temporal_ensemble_coeff: null

   # Training and loss computation.
   dropout: 0.1
diff --git a/lerobot/configs/policy/act_real.yaml b/lerobot/configs/policy/act_real.yaml
index c2f7158f..058104f4 100644
--- a/lerobot/configs/policy/act_real.yaml
+++ b/lerobot/configs/policy/act_real.yaml
@@ -107,7 +107,7 @@ policy:
   n_vae_encoder_layers: 4

   # Inference.
-  temporal_ensemble_momentum: null
+  temporal_ensemble_coeff: null

   # Training and loss computation.
   dropout: 0.1
diff --git a/lerobot/configs/policy/act_real_no_state.yaml b/lerobot/configs/policy/act_real_no_state.yaml
index 5b8a13b4..08261050 100644
--- a/lerobot/configs/policy/act_real_no_state.yaml
+++ b/lerobot/configs/policy/act_real_no_state.yaml
@@ -103,7 +103,7 @@ policy:
   n_vae_encoder_layers: 4

   # Inference.
-  temporal_ensemble_momentum: null
+  temporal_ensemble_coeff: null

   # Training and loss computation.
   dropout: 0.1
diff --git a/tests/test_policies.py b/tests/test_policies.py
index bc9c34ff..63f394e9 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -16,6 +16,7 @@
 import inspect
 from pathlib import Path

+import einops
 import pytest
 import torch
 from huggingface_hub import PyTorchModelHubMixin
@@ -26,6 +27,7 @@ from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
 from lerobot.common.policies.factory import (
     _policy_cfg_from_hydra_cfg,
     get_policy_and_config_classes,
@@ -33,7 +35,7 @@ from lerobot.common.policies.factory import (
 )
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, seeded_context
 from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
         assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
     for key in saved_actions:
         assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
+
+
+def test_act_temporal_ensembler():
+    """Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
+    temporal_ensemble_coeff = 0.01
+    chunk_size = 100
+    episode_length = 101
+    ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
+    # A batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
+    # "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
+    with seeded_context(0):
+        # Dimension is (batch, episode_length, chunk_size, action_dim(=1))
+        # Stepping through the episode_length dim is like running inference at each rollout step and getting
+        # a different action chunk.
+        batch_seq = torch.stack(
+            [
+                torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
+                torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
+                torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
+            ],
+            dim=0,
+        ).unsqueeze(-1)  # unsqueeze for action dim
+    batch_size = batch_seq.shape[0]
+    # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
+    # dimension of `batch_seq`.
+    weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
+
+    # Simulate stepping through a rollout and computing a batch of actions with the model on each step.
+    for i in range(episode_length):
+        # Mock a batch of actions.
+        actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
+        online_avg = ensembler.update(actions)
+        # Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
+        # Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
+        # What we want to do is take diagonal slices across it starting from the left.
+        # eg: chunk_size=4, episode_length=6
+        # ┌───────┐
+        # │0 1 2 3│
+        # │1 2 3 4│
+        # │2 3 4 5│
+        # │3 4 5 6│
+        # │4 5 6 7│
+        # │5 6 7 8│
+        # └───────┘
+        chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
+        episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
+        seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
+        offline_avg = (
+            einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
+        )
+        # Sanity check. The average should be between the extrema.
+        assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
+        assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
+        # Selected atol=1e-4 keeping in mind actions in [-1, 1] and expecting 0.01% error.
+        assert torch.allclose(online_avg, offline_avg, atol=1e-4)
+
+
+if __name__ == "__main__":
+    test_act_temporal_ensembler()
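Aside (not part of the patch): a small standalone illustration of the diagonal slicing used in `test_act_temporal_ensembler`, with a grid that is easy to check by hand (chunk_size=4 and rollout step i=5, matching the grid drawn in the test comment).

```python
# Illustration of the diagonal indexing from test_act_temporal_ensembler (not part of the patch).
import torch

chunk_size = 4
i = 5  # current rollout step

# Same indexing expressions as in the test.
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)  # tensor([3, 2, 1, 0])
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]  # tensor([2, 3, 4, 5])

# Each (episode step, chunk offset) pair sums to i=5: these are exactly the previously predicted
# actions that target environment step 5. Steps 0 and 1 no longer contribute because their 4-step
# chunks (covering steps 0-3 and 1-4) do not reach step 5.
print(list(zip(episode_step_indices.tolist(), chunk_indices.tolist())))
# [(2, 3), (3, 2), (4, 1), (5, 0)]
```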