Merge pull request #80 from huggingface/alexander-soare/unify_policy_api

Unify policy API
Alexander Soare 2024-04-17 17:05:22 +01:00 committed by GitHub
commit d5c4b0c344
9 changed files with 156 additions and 122 deletions

View File

@ -54,7 +54,7 @@ done = False
while not done:
for batch in dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
info = policy(batch)
info = policy.update(batch)
if step % log_freq == 0:
print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
step += 1
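For comparison, the inference side of the unified API goes through `select_action`. A minimal rollout sketch (not part of this PR): `policy`, `device`, and `env` are assumed to exist as in the example above, and `observation_to_batch` is a hypothetical helper that builds the `observation.*` dict the policy expects.

import torch

policy.reset()
observation, _ = env.reset()
done = False
while not done:
    batch = {k: v.to(device, non_blocking=True) for k, v in observation_to_batch(observation).items()}
    with torch.no_grad():
        action = policy.select_action(batch)  # (batch_size, action_dim) tensor
    observation, reward, terminated, truncated, _ = env.step(action[0].cpu().numpy())
    done = terminated or truncated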

View File

@ -103,12 +103,21 @@ class ActionChunkingTransformerConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError("`vision_backbone` must be one of the ResNet variants.")
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.use_temporal_aggregation:
raise NotImplementedError("Temporal aggregation is not yet implemented.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
"The chunk size is the upper bound for the number of action steps per model invocation."
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.camera_names != ["top"]:
raise ValueError("For now, `camera_names` can only be ['top']")
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")

View File

@ -61,9 +61,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""
name = "act"
_multiple_obs_steps_not_handled_msg = (
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
"""
@ -74,8 +71,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
super().__init__()
if cfg is None:
cfg = ActionChunkingTransformerConfig()
if cfg.n_obs_steps != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
@ -173,73 +168,47 @@ class ActionChunkingTransformerPolicy(nn.Module):
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
# shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
return self._action_queue.popleft()
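The queueing logic above can be illustrated in isolation; a minimal sketch of the pattern with illustrative shapes.

from collections import deque

import torch

batch_size, n_action_steps, action_dim = 2, 5, 4
queue = deque([], maxlen=n_action_steps)
# The model emits a chunk of shape (batch_size, n_action_steps, action_dim), but deque.extend
# iterates over dim 0, so transpose to make the time axis come first.
chunk = torch.randn(batch_size, n_action_steps, action_dim)
queue.extend(chunk.transpose(0, 1))
action = queue.popleft()
assert action.shape == (batch_size, action_dim)  # one action per environment for this step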
@torch.no_grad
def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""Use the action chunking transformer to generate a sequence of actions."""
self.eval()
self._preprocess_batch(batch, add_obs_steps_dim=True)
def forward(self, batch, **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
action = self.forward(batch, return_loss=False)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
return action[: self.cfg.n_action_steps]
loss_dict = {"l1_loss": l1_loss}
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
else:
loss_dict["loss"] = l1_loss
def __call__(self, *args, **kwargs) -> dict:
# TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
return self.update(*args, **kwargs)
def _preprocess_batch(
self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
) -> dict[str, Tensor]:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
}
"""
if add_obs_steps_dim:
# Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
# this just amounts to an unsqueeze.
for k in batch:
if k.startswith("observation."):
batch[k] = batch[k].unsqueeze(1)
if batch["observation.state"].shape[1] != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
batch["observation.state"] = batch["observation.state"].squeeze(1)
# TODO(alexander-soare): generalize this to multiple images.
assert (
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
), "ACT only handles one image for now."
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
return loss_dict
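The closed-form KL term used in `forward` (a diagonal Gaussian against a standard normal) can be sanity-checked against torch.distributions; a small sketch with random inputs.

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(8, 32)            # (B, L) latent means
log_sigma_x2 = torch.randn(8, 32)  # (B, L) latent log-variances
closed_form = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - log_sigma_x2.exp())).sum(-1).mean()
reference = kl_divergence(Normal(mu, (0.5 * log_sigma_x2).exp()), Normal(0.0, 1.0)).sum(-1).mean()
torch.testing.assert_close(closed_form, reference)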
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self._preprocess_batch(batch)
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.chunk_size * num_slices
assert batch_size % self.cfg.chunk_size == 0
assert batch_size % num_slices == 0
loss = self.forward(batch, return_loss=True)["loss"]
loss_dict = self.forward(batch)
loss = loss_dict["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
self.optimizer.step()
@ -254,68 +223,64 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
"""A forward pass through the DNN part of this policy with optional loss computation."""
images = self.image_normalizer(batch["observation.images.top"])
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
if return_loss: # training time
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
batch["observation.state"], images, batch["action"]
)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {}
loss_dict["l1"] = l1_loss
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _ = self._forward(batch["observation.state"], images)
return action
def _forward(
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
This function expects `batch` to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
image: (B, N, C, H, W) batch of N camera frames.
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
VAE is enabled and the model is in training mode.
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
raise ValueError(
f"The following camera images are missing from the provided batch: {missing}. Check the "
"configuration parameter: `camera_names`."
)
# Stack images in the order dictated by the camera names.
batch["observation.images"] = torch.stack(
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
dim=-4,
)
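Stacking on dim=-4 inserts the camera axis just before (C, H, W); a small shape check with illustrative sizes.

import torch

camera_names = ["top"]  # only "top" is supported for now; more cameras would simply grow this axis
batch = {f"observation.images.{name}": torch.randn(2, 3, 96, 96) for name in camera_names}  # (B, C, H, W) each
stacked = torch.stack([batch[f"observation.images.{name}"] for name in camera_names], dim=-4)
assert stacked.shape == (2, len(camera_names), 3, 96, 96)  # (B, n_cameras, C, H, W)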
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
(B, S, A) batch of action sequences
(B, chunk_size, action_dim) batch of action sequences
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension.
"""
if self.cfg.use_vae and self.training:
assert (
actions is not None
"action" in batch
), "actions must be provided when using the variational objective in training mode."
batch_size = robot_state.shape[0]
self._stack_images(batch)
batch_size = batch["observation.state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.cfg.use_vae and actions is not None:
if self.cfg.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(actions) # (B, S, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
1
) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
# Prepare fixed positional embedding.
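The VAE encoder input built just above is a token sequence of length chunk_size + 2; a shape sketch with illustrative dimensions.

import torch

B, S, D = 2, 100, 512  # illustrative batch size, chunk size, and hidden dimension
cls_embed = torch.randn(B, 1, D)
robot_state_embed = torch.randn(B, 1, D)
action_embed = torch.randn(B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], dim=1)
assert vae_encoder_input.shape == (B, S + 2, D)  # [cls, joint state, action sequence] tokens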
@ -337,15 +302,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
robot_state.device
batch["observation.state"].device
)
# Prepare all other transformer encoder inputs.
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
for cam_id, _ in enumerate(self.cfg.camera_names):
cam_features = self.backbone(image[:, cam_id])["feature_map"]
images = self.image_normalizer(batch["observation.images"])
for cam_index in range(len(self.cfg.camera_names)):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
@ -355,7 +321,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
latent_embed = self.encoder_latent_input_proj(latent_sample)
# Stack encoder input and positional embeddings moving to (S, B, C).

View File

@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
super().__init__()
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__()
# TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0
if cfg is None:
@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
action = self._queues["action"].popleft()
return action
def forward(self, batch, **_):
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def update(self, batch: dict[str, Tensor], **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.diffusion.train()
loss = self.diffusion.compute_loss(batch)
loss = self.forward(batch)["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(

View File

@ -0,0 +1,45 @@
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policy classes to subclass a base class.
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
how to implement new policies.
"""
from typing import Protocol, runtime_checkable
from torch import Tensor
@runtime_checkable
class Policy(Protocol):
"""The required interface for implementing a policy."""
name: str
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and maybe other information.
"""
def select_action(self, batch: dict[str, Tensor]):
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
def update(self, batch):
"""Does compute_loss then an optimization step.
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
disappear.
"""

View File

@ -330,6 +330,10 @@ class TDMPCPolicy(nn.Module):
return td_target
def forward(self, batch, step):
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
raise NotImplementedError()
def update(self, batch, step):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()

View File

@ -67,6 +67,4 @@ policy:
utd: 1
delta_timestamps:
observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(${policy.chunk_size})]"

View File

@ -257,7 +257,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy(batch, step=step)
train_info = policy.update(batch, step=step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
@ -318,7 +318,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy(batch, step)
train_info = policy.update(batch, step)
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

View File

@ -4,11 +4,13 @@ import torch
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:
- Making the policy object.
- Checking that the policy follows the correct protocol.
- Updating the policy.
- Using the policy to select actions at inference time.
- Checking that the selected action can be applied to the environment.
@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides):
f"policy={policy_name}",
f"device={DEVICE}",
]
+ extra_overrides
+ extra_overrides,
)
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
# Check that we run select_actions and get the appropriate output.
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=2)
@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides):
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy
policy(batch, step=0)
policy.update(batch, step=0)
# reset the policy and environment
policy.reset()
@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test step through policy
env.step(action)