Act temporal ensembling (#186)
commit 625f0557ef
parent 4d7d41cdee
@@ -66,8 +66,12 @@ class ACTConfig:
             documentation in the policy class).
         latent_dim: The VAE's latent dimension.
         n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
-        use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
-            environment step.
+        temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
+            actions for a given time step over multiple policy invocations. Updates are calculated as:
+            x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describe a different
+            parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
+            formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
+            require `n_action_steps == 1` (since we need to query the policy every step anyway).
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
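To make the docstring's update rule concrete, here is a minimal standalone sketch (illustrative only, not part of the commit) of the EMA applied to successive predictions for one environment step; `alpha = exp(-0.01)` is the value the docstring relates to the original code's `m = 0.01`.

```python
import math

import torch

# Illustrative only: the EMA update from the docstring, x̄ₙ = α·x̄ₙ₋₁ + (1 - α)·xₙ,
# applied to successive predictions of the *same* environment step.
alpha = math.exp(-0.01)  # ≈ 0.99005, corresponding to m = 0.01 per the docstring

# Toy predictions for one environment step from 4 successive policy invocations.
predictions = torch.tensor([[1.0, 0.0], [1.1, 0.1], [0.9, -0.1], [1.2, 0.2]])

ensembled = predictions[0].clone()
for x_n in predictions[1:]:
    ensembled = alpha * ensembled + (1 - alpha) * x_n
print(ensembled)  # stays close to the earliest prediction because alpha ≈ 0.99
```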
@@ -122,7 +126,7 @@ class ACTConfig:
     n_vae_encoder_layers: int = 4
 
     # Inference.
-    use_temporal_aggregation: bool = False
+    temporal_ensemble_momentum: float | None = None
 
     # Training and loss computation.
     dropout: float = 0.1
@@ -134,8 +138,11 @@ class ACTConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
-        if self.use_temporal_aggregation:
-            raise NotImplementedError("Temporal aggregation is not yet implemented.")
+        if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
+            raise NotImplementedError(
+                "`n_action_steps` must be 1 when using temporal ensembling. This is "
+                "because the policy needs to be queried every step to compute the ensembled action."
+            )
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
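As a usage illustration of the new check (the import path is assumed, and the remaining `ACTConfig` fields are assumed to have defaults, as the field declarations above suggest):

```python
from lerobot.common.policies.act.configuration_act import ACTConfig  # assumed module path

# OK: ensembling with the policy queried at every environment step.
ACTConfig(temporal_ensemble_momentum=0.99, n_action_steps=1)

# Trips the new validation: ensembling combined with multi-step action execution.
try:
    ACTConfig(temporal_ensemble_momentum=0.99, n_action_steps=10)
except NotImplementedError as err:
    print(err)
```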
@@ -61,7 +61,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         super().__init__()
         if config is None:
             config = ACTConfig()
-        self.config = config
+        self.config: ACTConfig = config
 
         self.normalize_inputs = Normalize(
             config.input_shapes, config.input_normalization_modes, dataset_stats
@@ -81,7 +81,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
 
     def reset(self):
         """This should be called whenever the environment is reset."""
-        if self.config.n_action_steps is not None:
+        if self.config.temporal_ensemble_momentum is not None:
+            self._ensembled_actions = None
+        else:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
     @torch.no_grad
@@ -97,6 +99,28 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
 
+        # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
+        # the first action.
+        if self.config.temporal_ensemble_momentum is not None:
+            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
+            actions = self.unnormalize_outputs({"action": actions})["action"]
+            if self._ensembled_actions is None:
+                # Initializes `self._ensembled_actions` to the sequence of actions predicted during the first
+                # time step of the episode.
+                self._ensembled_actions = actions.clone()
+            else:
+                # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
+                # the EMA update for those entries.
+                alpha = self.config.temporal_ensemble_momentum
+                self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
+                # The last action, which has no prior moving average, needs to get concatenated onto the end.
+                self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
+            # "Consume" the first action.
+            action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
+            return action
+
+        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
+        # querying the policy.
         if len(self._action_queue) == 0:
             actions = self.model(batch)[0][:, : self.config.n_action_steps]
 
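The buffer logic above can be mirrored in a self-contained toy example (assumed shapes and a made-up `ToyEnsembler` class, purely for illustration): the first call of an episode seeds the buffer with a full predicted chunk, each later call EMAs the overlapping `chunk_size - 1` entries against the fresh predictions and appends the newest action, and the front entry is popped as the action to execute.

```python
import torch

class ToyEnsembler:
    """Standalone mimic of the EMA buffer logic above (illustrative, not the repo's class)."""

    def __init__(self, alpha: float):
        self.alpha = alpha
        self.ensembled = None  # plays the role of self._ensembled_actions

    def step(self, new_chunk: torch.Tensor) -> torch.Tensor:
        if self.ensembled is None:
            # First invocation of the episode: seed the buffer with the full predicted chunk.
            self.ensembled = new_chunk.clone()
        else:
            # The buffer holds chunk_size - 1 entries that overlap the fresh chunk: EMA them,
            # then append the newest action, which has no prior moving average.
            self.ensembled = self.alpha * self.ensembled + (1 - self.alpha) * new_chunk[:, :-1]
            self.ensembled = torch.cat([self.ensembled, new_chunk[:, -1:]], dim=1)
        # Pop ("consume") the first action.
        action, self.ensembled = self.ensembled[:, 0], self.ensembled[:, 1:]
        return action

ensembler = ToyEnsembler(alpha=0.99)
for t in range(3):
    chunk = torch.randn(1, 5, 2)  # stand-in for (batch_size, chunk_size, action_dim) model output
    print(t, ensembler.step(chunk).shape)  # torch.Size([1, 2]) every step
```

With `alpha` close to 1, the executed action leans heavily on earlier predictions for that time step, which is the smoothing effect temporal ensembling is after.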
@@ -73,7 +73,7 @@ policy:
   n_vae_encoder_layers: 4
 
   # Inference.
-  use_temporal_aggregation: false
+  temporal_ensemble_momentum: null
 
   # Training and loss computation.
   dropout: 0.1