Fix ACT temporal ensembling (#319)
This commit is contained in:
parent
5e54e39795
commit
c0101f0948
|
@ -80,7 +80,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -76,12 +76,10 @@ class ACTConfig:
|
||||||
documentation in the policy class).
|
documentation in the policy class).
|
||||||
latent_dim: The VAE's latent dimension.
|
latent_dim: The VAE's latent dimension.
|
||||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||||
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
|
||||||
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
|
||||||
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
|
||||||
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
more information on how ensembling works, please see `ACTTemporalEnsembler`.
|
||||||
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
|
||||||
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
|
||||||
dropout: Dropout to use in the transformer layers (see code for details).
|
dropout: Dropout to use in the transformer layers (see code for details).
|
||||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||||
|
@ -139,7 +137,8 @@ class ACTConfig:
|
||||||
n_vae_encoder_layers: int = 4
|
n_vae_encoder_layers: int = 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: float | None = None
|
# Note: the value used in ACT when temporal ensembling is enabled is 0.01.
|
||||||
|
temporal_ensemble_coeff: float | None = None
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
|
@ -151,7 +150,7 @@ class ACTConfig:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
)
|
)
|
||||||
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||||||
"because the policy needs to be queried every step to compute the ensembled action."
|
"because the policy needs to be queried every step to compute the ensembled action."
|
||||||
|
|
|
@ -77,12 +77,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
|
|
||||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
|
|
||||||
|
if config.temporal_ensemble_coeff is not None:
|
||||||
|
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
if self.config.temporal_ensemble_momentum is not None:
|
if self.config.temporal_ensemble_coeff is not None:
|
||||||
self._ensembled_actions = None
|
self.temporal_ensembler.reset()
|
||||||
else:
|
else:
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
@ -100,24 +103,12 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
if len(self.expected_image_keys) > 0:
|
if len(self.expected_image_keys) > 0:
|
||||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
|
|
||||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||||
# the first action.
|
# we are ensembling over.
|
||||||
if self.config.temporal_ensemble_momentum is not None:
|
if self.config.temporal_ensemble_coeff is not None:
|
||||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
if self._ensembled_actions is None:
|
action = self.temporal_ensembler.update(actions)
|
||||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
|
||||||
# time step of the episode.
|
|
||||||
self._ensembled_actions = actions.clone()
|
|
||||||
else:
|
|
||||||
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
|
||||||
# the EMA update for those entries.
|
|
||||||
alpha = self.config.temporal_ensemble_momentum
|
|
||||||
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
|
|
||||||
# The last action, which has no prior moving average, needs to get concatenated onto the end.
|
|
||||||
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
|
|
||||||
# "Consume" the first action.
|
|
||||||
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||||
|
@ -162,6 +153,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ACTTemporalEnsembler:
|
||||||
|
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
|
||||||
|
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
|
||||||
|
|
||||||
|
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
|
||||||
|
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
|
||||||
|
coefficient works:
|
||||||
|
- Setting it to 0 uniformly weighs all actions.
|
||||||
|
- Setting it positive gives more weight to older actions.
|
||||||
|
- Setting it negative gives more weight to newer actions.
|
||||||
|
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
|
||||||
|
results in older actions being weighed more highly than newer actions (the experiments documented in
|
||||||
|
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
|
||||||
|
detrimental: doing so aggressively may diminish the benefits of action chunking).
|
||||||
|
|
||||||
|
Here we use an online method for computing the average rather than caching a history of actions in
|
||||||
|
order to compute the average offline. For a simple 1D sequence it looks something like:
|
||||||
|
|
||||||
|
```
|
||||||
|
import torch
|
||||||
|
|
||||||
|
seq = torch.linspace(8, 8.5, 100)
|
||||||
|
print(seq)
|
||||||
|
|
||||||
|
m = 0.01
|
||||||
|
exp_weights = torch.exp(-m * torch.arange(len(seq)))
|
||||||
|
print(exp_weights)
|
||||||
|
|
||||||
|
# Calculate offline
|
||||||
|
avg = (exp_weights * seq).sum() / exp_weights.sum()
|
||||||
|
print("offline", avg)
|
||||||
|
|
||||||
|
# Calculate online
|
||||||
|
for i, item in enumerate(seq):
|
||||||
|
if i == 0:
|
||||||
|
avg = item
|
||||||
|
continue
|
||||||
|
avg *= exp_weights[:i].sum()
|
||||||
|
avg += item * exp_weights[i]
|
||||||
|
avg /= exp_weights[:i+1].sum()
|
||||||
|
print("online", avg)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||||
|
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Resets the online computation variables."""
|
||||||
|
self.ensembled_actions = None
|
||||||
|
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
|
||||||
|
self.ensembled_actions_count = None
|
||||||
|
|
||||||
|
def update(self, actions: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
|
||||||
|
time steps, and pop/return the next batch of actions in the sequence.
|
||||||
|
"""
|
||||||
|
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
||||||
|
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
||||||
|
if self.ensembled_actions is None:
|
||||||
|
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||||
|
# time step of the episode.
|
||||||
|
self.ensembled_actions = actions.clone()
|
||||||
|
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||||
|
# operations later.
|
||||||
|
self.ensembled_actions_count = torch.ones(
|
||||||
|
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||||
|
# the online update for those entries.
|
||||||
|
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
||||||
|
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||||
|
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
||||||
|
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
||||||
|
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||||
|
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||||
|
self.ensembled_actions_count = torch.cat(
|
||||||
|
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||||
|
)
|
||||||
|
# "Consume" the first action.
|
||||||
|
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||||
|
self.ensembled_actions[:, 0],
|
||||||
|
self.ensembled_actions[:, 1:],
|
||||||
|
self.ensembled_actions_count[1:],
|
||||||
|
)
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
class ACT(nn.Module):
|
class ACT(nn.Module):
|
||||||
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -107,7 +107,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -103,7 +103,7 @@ policy:
|
||||||
n_vae_encoder_layers: 4
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
# Inference.
|
# Inference.
|
||||||
temporal_ensemble_momentum: null
|
temporal_ensemble_coeff: null
|
||||||
|
|
||||||
# Training and loss computation.
|
# Training and loss computation.
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import inspect
|
import inspect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import einops
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
|
@ -26,6 +27,7 @@ from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
|
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
|
||||||
from lerobot.common.policies.factory import (
|
from lerobot.common.policies.factory import (
|
||||||
_policy_cfg_from_hydra_cfg,
|
_policy_cfg_from_hydra_cfg,
|
||||||
get_policy_and_config_classes,
|
get_policy_and_config_classes,
|
||||||
|
@ -33,7 +35,7 @@ from lerobot.common.policies.factory import (
|
||||||
)
|
)
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||||
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
|
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
|
||||||
for key in saved_actions:
|
for key in saved_actions:
|
||||||
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_act_temporal_ensembler():
|
||||||
|
"""Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
|
||||||
|
temporal_ensemble_coeff = 0.01
|
||||||
|
chunk_size = 100
|
||||||
|
episode_length = 101
|
||||||
|
ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
|
||||||
|
# An batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
|
||||||
|
# "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
|
||||||
|
with seeded_context(0):
|
||||||
|
# Dimension is (batch, episode_length, chunk_size, action_dim(=1))
|
||||||
|
# Stepping through the episode_length dim is like running inference at each rollout step and getting
|
||||||
|
# a different action chunk.
|
||||||
|
batch_seq = torch.stack(
|
||||||
|
[
|
||||||
|
torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
|
||||||
|
torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
|
||||||
|
torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
).unsqueeze(-1) # unsqueeze for action dim
|
||||||
|
batch_size = batch_seq.shape[0]
|
||||||
|
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
|
||||||
|
# dimension of `batch_seq`.
|
||||||
|
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
|
||||||
|
|
||||||
|
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
|
||||||
|
for i in range(episode_length):
|
||||||
|
# Mock a batch of actions.
|
||||||
|
actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
|
||||||
|
online_avg = ensembler.update(actions)
|
||||||
|
# Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
|
||||||
|
# Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
|
||||||
|
# What we want to do is take diagonal slices across it starting from the left.
|
||||||
|
# eg: chunk_size=4, episode_length=6
|
||||||
|
# ┌───────┐
|
||||||
|
# │0 1 2 3│
|
||||||
|
# │1 2 3 4│
|
||||||
|
# │2 3 4 5│
|
||||||
|
# │3 4 5 6│
|
||||||
|
# │4 5 6 7│
|
||||||
|
# │5 6 7 8│
|
||||||
|
# └───────┘
|
||||||
|
chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
|
||||||
|
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
|
||||||
|
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
|
||||||
|
offline_avg = (
|
||||||
|
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
|
||||||
|
)
|
||||||
|
# Sanity check. The average should be between the extrema.
|
||||||
|
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
|
||||||
|
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||||
|
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||||
|
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_act_temporal_ensembler()
|
||||||
|
|
Loading…
Reference in New Issue