Merge pull request #80 from huggingface/alexander-soare/unify_policy_api

Unify policy API
Alexander Soare 2024-04-17 17:05:22 +01:00 committed by GitHub
commit d5c4b0c344
9 changed files with 156 additions and 122 deletions

View File

@@ -54,7 +54,7 @@ done = False
 while not done:
     for batch in dataloader:
         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        info = policy(batch)
+        info = policy.update(batch)
         if step % log_freq == 0:
             print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
         step += 1

View File

@@ -103,12 +103,21 @@ class ActionChunkingTransformerConfig:
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
-            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
+            raise ValueError(
+                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
+            )
         if self.use_temporal_aggregation:
             raise NotImplementedError("Temporal aggregation is not yet implemented.")
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
-                "The chunk size is the upper bound for the number of action steps per model invocation."
+                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
+                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
             )
+        if self.n_obs_steps != 1:
+            raise ValueError(
+                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
+            )
         if self.camera_names != ["top"]:
-            raise ValueError("For now, `camera_names` can only be ['top']")
+            raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
+        if len(set(self.camera_names)) != len(self.camera_names):
+            raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")

View File

@@ -61,9 +61,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
     """

     name = "act"
-    _multiple_obs_steps_not_handled_msg = (
-        "ActionChunkingTransformerPolicy does not handle multiple observation steps."
-    )

     def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
         """
@@ -74,8 +71,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         super().__init__()
         if cfg is None:
             cfg = ActionChunkingTransformerConfig()
-        if cfg.n_obs_steps != 1:
-            raise ValueError(self._multiple_obs_steps_not_handled_msg)
         self.cfg = cfg

         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
@@ -173,73 +168,47 @@ class ActionChunkingTransformerPolicy(nn.Module):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        self.eval()
         if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
-            # shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
+            # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
+            # has shape (n_action_steps, batch_size, *), hence the transpose.
+            self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
         return self._action_queue.popleft()

-    @torch.no_grad
-    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
-        """Use the action chunking transformer to generate a sequence of actions."""
-        self.eval()
-        self._preprocess_batch(batch, add_obs_steps_dim=True)
-
-        action = self.forward(batch, return_loss=False)
-
-        return action[: self.cfg.n_action_steps]
-
-    def __call__(self, *args, **kwargs) -> dict:
-        # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
-        return self.update(*args, **kwargs)
-
-    def _preprocess_batch(
-        self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
-    ) -> dict[str, Tensor]:
-        """
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
-            "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
-        }
-        """
-        if add_obs_steps_dim:
-            # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
-            # this just amounts to an unsqueeze.
-            for k in batch:
-                if k.startswith("observation."):
-                    batch[k] = batch[k].unsqueeze(1)
-
-        if batch["observation.state"].shape[1] != 1:
-            raise ValueError(self._multiple_obs_steps_not_handled_msg)
-        batch["observation.state"] = batch["observation.state"].squeeze(1)
-
-        # TODO(alexander-soare): generalize this to multiple images.
-        assert (
-            sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
-        ), "ACT only handles one image for now."
-        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
-        # the image index dimension.
+    def forward(self, batch, **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        loss_dict = {"l1_loss": l1_loss}
+        if self.cfg.use_vae:
+            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss_dict["kld_loss"] = mean_kld
+            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss_dict["loss"] = l1_loss
+
+        return loss_dict

     def update(self, batch, **_) -> dict:
         """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
-        self._preprocess_batch(batch)
         self.train()

-        num_slices = self.cfg.batch_size
-        batch_size = self.cfg.chunk_size * num_slices
-
-        assert batch_size % self.cfg.chunk_size == 0
-        assert batch_size % num_slices == 0
-
-        loss = self.forward(batch, return_loss=True)["loss"]
+        loss_dict = self.forward(batch)
+        loss = loss_dict["loss"]
         loss.backward()

         grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
+            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
         )

         self.optimizer.step()
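For reference, the `mean_kld` expression in the new `forward` above implements the closed-form KL divergence between the diagonal-Gaussian latent and a standard normal (App. B of https://arxiv.org/abs/1312.6114), summed over the L latent dimensions and then averaged over the batch. In the code, `mu_hat` is μ and `log_sigma_x2_hat` is log σ².

\[
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  = -\tfrac{1}{2} \sum_{d=1}^{L} \big(1 + \log \sigma_d^2 - \mu_d^2 - \sigma_d^2\big)
\]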
@@ -254,68 +223,64 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return info

-    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        images = self.image_normalizer(batch["observation.images.top"])
-
-        if return_loss:  # training time
-            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
-                batch["observation.state"], images, batch["action"]
-            )
-
-            l1_loss = (
-                F.l1_loss(batch["action"], actions_hat, reduction="none")
-                * ~batch["action_is_pad"].unsqueeze(-1)
-            ).mean()
-
-            loss_dict = {}
-            loss_dict["l1"] = l1_loss
-            if self.cfg.use_vae:
-                # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-                # each dimension independently, we sum over the latent dimension to get the total
-                # KL-divergence per batch element, then take the mean over the batch.
-                # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-                mean_kld = (
-                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-                )
-                loss_dict["kl"] = mean_kld
-                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
-            else:
-                loss_dict["loss"] = loss_dict["l1"]
-
-            return loss_dict
-        else:
-            action, _ = self._forward(batch["observation.state"], images)
-            return action
-
-    def _forward(
-        self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
-    ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
-        """
-        Args:
-            robot_state: (B, J) batch of robot joint configurations.
-            image: (B, N, C, H, W) batch of N camera frames.
-            actions: (B, S, A) batch of actions from the target dataset which must be provided if the
-                VAE is enabled and the model is in training mode.
+    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Stacks all the images in a batch and puts them in a new key: "observation.images".
+
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, state_dim) batch of robot states.
+            "observation.images.{name}": (B, C, H, W) tensor of images.
+        }
+        """
+        # Check that there is only one image.
+        # TODO(alexander-soare): generalize this to multiple images.
+        provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
+        if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
+            raise ValueError(
+                f"The following camera images are missing from the provided batch: {missing}. Check the "
+                "configuration parameter: `camera_names`."
+            )
+        # Stack images in the order dictated by the camera names.
+        batch["observation.images"] = torch.stack(
+            [batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
+            dim=-4,
+        )
+
+    def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
+        """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
+
+        `batch` should have the following structure:
+        {
+            "observation.state": (B, state_dim) batch of robot states.
+            "observation.images": (B, n_cameras, C, H, W) batch of images.
+            "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
+        }
+
         Returns:
-            (B, S, A) batch of action sequences
+            (B, chunk_size, action_dim) batch of action sequences
             Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
                 latent dimension.
         """
         if self.cfg.use_vae and self.training:
             assert (
-                actions is not None
+                "action" in batch
             ), "actions must be provided when using the variational objective in training mode."

-        batch_size = robot_state.shape[0]
+        self._stack_images(batch)
+
+        batch_size = batch["observation.state"].shape[0]

         # Prepare the latent for input to the transformer encoder.
-        if self.cfg.use_vae and actions is not None:
+        if self.cfg.use_vae and "action" in batch:
             # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1)  # (B, 1, D)
-            action_embed = self.vae_encoder_action_input_proj(actions)  # (B, S, D)
+            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
+                1
+            )  # (B, 1, D)
+            action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
             vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)

             # Prepare fixed positional embedding.
@@ -337,15 +302,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
             latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
-                robot_state.device
+                batch["observation.state"].device
             )

         # Prepare all other transformer encoder inputs.
         # Camera observation features and positional embeddings.
         all_cam_features = []
         all_cam_pos_embeds = []
-        for cam_id, _ in enumerate(self.cfg.camera_names):
-            cam_features = self.backbone(image[:, cam_id])["feature_map"]
+        images = self.image_normalizer(batch["observation.images"])
+        for cam_index in range(len(self.cfg.camera_names)):
+            cam_features = self.backbone(images[:, cam_index])["feature_map"]
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
@@ -355,7 +321,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)

         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
+        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
         latent_embed = self.encoder_latent_input_proj(latent_sample)

         # Stack encoder input and positional embeddings moving to (S, B, C).

View File

@@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
     name = "diffusion"

     def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
-        super().__init__()
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
                 configuration class is used.
         """
+        super().__init__()
         # TODO(alexander-soare): LR scheduler will be removed.
         assert lr_scheduler_num_training_steps > 0
         if cfg is None:
@@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
         action = self._queues["action"].popleft()
         return action

-    def forward(self, batch, **_):
+    def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        loss = self.diffusion.compute_loss(batch)
+        return {"loss": loss}
+
+    def update(self, batch: dict[str, Tensor], **_) -> dict:
+        """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
         self.diffusion.train()

-        loss = self.diffusion.compute_loss(batch)
+        loss = self.forward(batch)["loss"]
         loss.backward()

         grad_norm = torch.nn.utils.clip_grad_norm_(

View File

@@ -0,0 +1,45 @@
+"""A protocol that all policies should follow.
+
+This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
+subclass a base class.
+
+The protocol structure, method signatures, and docstrings should be used by developers as a reference for
+how to implement new policies.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from torch import Tensor
+
+
+@runtime_checkable
+class Policy(Protocol):
+    """The required interface for implementing a policy."""
+
+    name: str
+
+    def reset(self):
+        """To be called whenever the environment is reset.
+
+        Does things like clearing caches.
+        """
+
+    def forward(self, batch: dict[str, Tensor]) -> dict:
+        """Run the batch through the model and compute the loss for training or validation.
+
+        Returns a dictionary with "loss" and maybe other information.
+        """
+
+    def select_action(self, batch: dict[str, Tensor]):
+        """Return one action to run in the environment (potentially in batch mode).
+
+        When the model uses a history of observations, or outputs a sequence of actions, this method deals
+        with caching.
+        """
+
+    def update(self, batch):
+        """Does compute_loss then an optimization step.
+
+        TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
+        disappear.
+        """

View File

@@ -330,6 +330,10 @@ class TDMPCPolicy(nn.Module):
         return td_target

     def forward(self, batch, step):
+        # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
+        raise NotImplementedError()
+
+    def update(self, batch, step):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()

View File

@@ -67,6 +67,4 @@ policy:
   utd: 1

   delta_timestamps:
-    observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
-    observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
     action: "[i / ${fps} for i in range(${policy.chunk_size})]"

View File

@@ -257,7 +257,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        train_info = policy(batch, step=step)
+        train_info = policy.update(batch, step=step)

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
@@ -318,7 +318,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
             for key in batch:
                 batch[key] = batch[key].to(cfg.device, non_blocking=True)

-            train_info = policy(batch, step)
+            train_info = policy.update(batch, step)

             if step % cfg.log_freq == 0:
                 log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)

View File

@@ -4,11 +4,13 @@ import torch
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.envs.factory import make_env
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.utils import init_hydra_config
 from .utils import DEVICE, DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
     [
@@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
         - Making the policy object.
+        - Checking that the policy follows the correct protocol.
         - Updating the policy.
         - Using the policy to select actions at inference time.
         - Test the action can be applied to the policy
@@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides):
             f"policy={policy_name}",
             f"device={DEVICE}",
         ]
-        + extra_overrides
+        + extra_overrides,
     )
     # Check that we can make the policy object.
     policy = make_policy(cfg)
+    # Check that the policy follows the required protocol.
+    assert isinstance(
+        policy, Policy
+    ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
     # Check that we run select_actions and get the appropriate output.
     dataset = make_dataset(cfg)
     env = make_env(cfg, num_parallel_envs=2)
@@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides):
         batch[key] = batch[key].to(DEVICE, non_blocking=True)

     # Test updating the policy
-    policy(batch, step=0)
+    policy.update(batch, step=0)

     # reset the policy and environment
     policy.reset()
@@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides):
     # Test step through policy
     env.step(action)