Merge pull request #80 from huggingface/alexander-soare/unify_policy_api
Unify policy API
commit d5c4b0c344
@@ -54,7 +54,7 @@ done = False
 while not done:
     for batch in dataloader:
         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        info = policy(batch)
+        info = policy.update(batch)
         if step % log_freq == 0:
             print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
         step += 1
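
With this change the example training loop calls `policy.update(batch)` explicitly instead of relying on `__call__`. A minimal sketch of the resulting loop, assuming `policy`, `dataloader`, `device`, and `log_freq` are set up as in the example above, and `training_steps` is a hypothetical step budget:

    step = 0
    done = False
    while not done:
        for batch in dataloader:
            # Move every tensor in the batch to the target device.
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            # `update` runs the forward pass, backpropagates, and steps the optimizer,
            # returning logging info such as "loss" and "update_s".
            info = policy.update(batch)
            if step % log_freq == 0:
                print(f"step: {step} loss: {info['loss']:.3f}")
            step += 1
            if step >= training_steps:  # hypothetical stopping criterion
                done = True
                break
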
@@ -103,12 +103,21 @@ class ActionChunkingTransformerConfig:
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
-            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
+            raise ValueError(
+                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
+            )
         if self.use_temporal_aggregation:
             raise NotImplementedError("Temporal aggregation is not yet implemented.")
         if self.n_action_steps > self.chunk_size:
             raise ValueError(
-                "The chunk size is the upper bound for the number of action steps per model invocation."
+                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
+                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
             )
+        if self.n_obs_steps != 1:
+            raise ValueError(
+                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
+            )
         if self.camera_names != ["top"]:
-            raise ValueError("For now, `camera_names` can only be ['top']")
+            raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
+        if len(set(self.camera_names)) != len(self.camera_names):
+            raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")
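
Validation now happens once, at config construction time, with error messages that echo the offending values. A quick sketch of the validation firing (the values and the import path are assumptions for illustration):

    from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

    try:
        # n_action_steps may not exceed chunk_size.
        cfg = ActionChunkingTransformerConfig(n_action_steps=200, chunk_size=100)
    except ValueError as e:
        print(e)  # "... Got 200 for `n_action_steps` and 100 for `chunk_size`."
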
@@ -61,9 +61,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
     """

     name = "act"
-    _multiple_obs_steps_not_handled_msg = (
-        "ActionChunkingTransformerPolicy does not handle multiple observation steps."
-    )

     def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
         """
@@ -74,8 +71,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         super().__init__()
         if cfg is None:
             cfg = ActionChunkingTransformerConfig()
-        if cfg.n_obs_steps != 1:
-            raise ValueError(self._multiple_obs_steps_not_handled_msg)
         self.cfg = cfg

         # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
@@ -173,73 +168,47 @@ class ActionChunkingTransformerPolicy(nn.Module):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
+        self.eval()
         if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
-            # shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
+            # `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
+            # has shape (n_action_steps, batch_size, *), hence the transpose.
+            self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
         return self._action_queue.popleft()

-    @torch.no_grad
-    def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
-        """Use the action chunking transformer to generate a sequence of actions."""
-        self.eval()
-        self._preprocess_batch(batch, add_obs_steps_dim=True)
-
-        action = self.forward(batch, return_loss=False)
-
-        return action[: self.cfg.n_action_steps]
-
-    def __call__(self, *args, **kwargs) -> dict:
-        # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
-        return self.update(*args, **kwargs)
-
-    def _preprocess_batch(
-        self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
-    ) -> dict[str, Tensor]:
-        """
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
-            "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
-        }
-        """
-        if add_obs_steps_dim:
-            # Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
-            # this just amounts to an unsqueeze.
-            for k in batch:
-                if k.startswith("observation."):
-                    batch[k] = batch[k].unsqueeze(1)
-
-        if batch["observation.state"].shape[1] != 1:
-            raise ValueError(self._multiple_obs_steps_not_handled_msg)
-        batch["observation.state"] = batch["observation.state"].squeeze(1)
-        # TODO(alexander-soare): generalize this to multiple images.
-        assert (
-            sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
-        ), "ACT only handles one image for now."
-        # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
-        # the image index dimension.
+    def forward(self, batch, **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
+
+        l1_loss = (
+            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+        ).mean()
+
+        loss_dict = {"l1_loss": l1_loss}
+        if self.cfg.use_vae:
+            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
+            # each dimension independently, we sum over the latent dimension to get the total
+            # KL-divergence per batch element, then take the mean over the batch.
+            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
+            mean_kld = (
+                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
+            )
+            loss_dict["kld_loss"] = mean_kld
+            loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
+        else:
+            loss_dict["loss"] = l1_loss
+
+        return loss_dict

     def update(self, batch, **_) -> dict:
         """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()
-        self._preprocess_batch(batch)
-
         self.train()

-        num_slices = self.cfg.batch_size
-        batch_size = self.cfg.chunk_size * num_slices
-
-        assert batch_size % self.cfg.chunk_size == 0
-        assert batch_size % num_slices == 0
-
-        loss = self.forward(batch, return_loss=True)["loss"]
+        loss_dict = self.forward(batch)
+        loss = loss_dict["loss"]
         loss.backward()

         grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
+            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
         )

         self.optimizer.step()
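
After this refactor, `forward` is a pure loss computation and `update` wraps it with backpropagation, gradient clipping, and the optimizer step, so the same code path serves both training and validation. A minimal sketch, assuming `policy` is an `ActionChunkingTransformerPolicy` and `batch` contains the keys its `_forward` docstring lists (including "action" and "action_is_pad"):

    import torch

    # Training: update() calls forward(), backpropagates, and steps the optimizer.
    info = policy.update(batch)
    print(info["loss"])

    # Validation: call forward() directly; no optimizer step is taken.
    policy.eval()
    with torch.no_grad():
        loss_dict = policy.forward(batch)  # {"l1_loss": ..., "loss": ..., maybe "kld_loss": ...}
    print(float(loss_dict["loss"]))
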
@@ -254,68 +223,64 @@ class ActionChunkingTransformerPolicy(nn.Module):

         return info

-    def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
-        """A forward pass through the DNN part of this policy with optional loss computation."""
-        images = self.image_normalizer(batch["observation.images.top"])
-
-        if return_loss:  # training time
-            actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
-                batch["observation.state"], images, batch["action"]
-            )
-
-            l1_loss = (
-                F.l1_loss(batch["action"], actions_hat, reduction="none")
-                * ~batch["action_is_pad"].unsqueeze(-1)
-            ).mean()
-
-            loss_dict = {}
-            loss_dict["l1"] = l1_loss
-            if self.cfg.use_vae:
-                # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
-                # each dimension independently, we sum over the latent dimension to get the total
-                # KL-divergence per batch element, then take the mean over the batch.
-                # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-                mean_kld = (
-                    (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-                )
-                loss_dict["kl"] = mean_kld
-                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
-            else:
-                loss_dict["loss"] = loss_dict["l1"]
-            return loss_dict
-        else:
-            action, _ = self._forward(batch["observation.state"], images)
-            return action
-
-    def _forward(
-        self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
-    ) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
-        """
-        Args:
-            robot_state: (B, J) batch of robot joint configurations.
-            image: (B, N, C, H, W) batch of N camera frames.
-            actions: (B, S, A) batch of actions from the target dataset which must be provided if the
-                VAE is enabled and the model is in training mode.
+    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Stacks all the images in a batch and puts them in a new key: "observation.images".
+
+        This function expects `batch` to have (at least):
+        {
+            "observation.state": (B, state_dim) batch of robot states.
+            "observation.images.{name}": (B, C, H, W) tensor of images.
+        }
+        """
+        # Check that there is only one image.
+        # TODO(alexander-soare): generalize this to multiple images.
+        provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
+        if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
+            raise ValueError(
+                f"The following camera images are missing from the provided batch: {missing}. Check the "
+                "configuration parameter: `camera_names`."
+            )
+        # Stack images in the order dictated by the camera names.
+        batch["observation.images"] = torch.stack(
+            [batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
+            dim=-4,
+        )
+
+    def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
+        """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
+
+        `batch` should have the following structure:
+
+        {
+            "observation.state": (B, state_dim) batch of robot states.
+            "observation.images": (B, n_cameras, C, H, W) batch of images.
+            "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
+        }

         Returns:
-            (B, S, A) batch of action sequences
+            (B, chunk_size, action_dim) batch of action sequences
             Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
                 latent dimension.
         """
         if self.cfg.use_vae and self.training:
             assert (
-                actions is not None
+                "action" in batch
             ), "actions must be provided when using the variational objective in training mode."

-        batch_size = robot_state.shape[0]
+        self._stack_images(batch)
+
+        batch_size = batch["observation.state"].shape[0]

         # Prepare the latent for input to the transformer encoder.
-        if self.cfg.use_vae and actions is not None:
+        if self.cfg.use_vae and "action" in batch:
             # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(robot_state).unsqueeze(1)  # (B, 1, D)
-            action_embed = self.vae_encoder_action_input_proj(actions)  # (B, S, D)
+            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
+                1
+            )  # (B, 1, D)
+            action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
             vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)

             # Prepare fixed positional embedding.
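
The camera handling is now data-driven: `_stack_images` collects every `observation.images.{name}` key named in `cfg.camera_names` into a single (B, n_cameras, C, H, W) tensor. A standalone sketch of the stacking behavior, with hypothetical shapes and `camera_names = ["top"]` as the config currently requires:

    import torch

    camera_names = ["top"]
    batch = {
        "observation.state": torch.randn(8, 14),                # (B, state_dim)
        "observation.images.top": torch.randn(8, 3, 480, 640),  # (B, C, H, W)
    }
    # dim=-4 inserts the camera axis just before (C, H, W), giving (B, n_cameras, C, H, W).
    batch["observation.images"] = torch.stack(
        [batch[f"observation.images.{name}"] for name in camera_names], dim=-4
    )
    print(batch["observation.images"].shape)  # torch.Size([8, 1, 3, 480, 640])
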
@@ -337,15 +302,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
             latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
-                robot_state.device
+                batch["observation.state"].device
             )

         # Prepare all other transformer encoder inputs.
         # Camera observation features and positional embeddings.
         all_cam_features = []
         all_cam_pos_embeds = []
-        for cam_id, _ in enumerate(self.cfg.camera_names):
-            cam_features = self.backbone(image[:, cam_id])["feature_map"]
+        images = self.image_normalizer(batch["observation.images"])
+        for cam_index in range(len(self.cfg.camera_names)):
+            cam_features = self.backbone(images[:, cam_index])["feature_map"]
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
@@ -355,7 +321,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)

         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(robot_state)
+        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
         latent_embed = self.encoder_latent_input_proj(latent_sample)

         # Stack encoder input and positional embeddings moving to (S, B, C).
@@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
     name = "diffusion"

     def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
-        super().__init__()
         """
         Args:
             cfg: Policy configuration class instance or None, in which case the default instantiation of the
                 configuration class is used.
         """
+        super().__init__()
         # TODO(alexander-soare): LR scheduler will be removed.
         assert lr_scheduler_num_training_steps > 0
         if cfg is None:
@@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
         action = self._queues["action"].popleft()
         return action

-    def forward(self, batch, **_):
+    def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
+        """Run the batch through the model and compute the loss for training or validation."""
+        loss = self.diffusion.compute_loss(batch)
+        return {"loss": loss}
+
+    def update(self, batch: dict[str, Tensor], **_) -> dict:
+        """Run the model in train mode, compute the loss, and do an optimization step."""
         start_time = time.time()

         self.diffusion.train()

-        loss = self.diffusion.compute_loss(batch)
+        loss = self.forward(batch)["loss"]
         loss.backward()

         grad_norm = torch.nn.utils.clip_grad_norm_(
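
`DiffusionPolicy` now mirrors the ACT policy: `forward` returns a loss dictionary and `update` performs the optimization step. That makes a gradient-free validation pass uniform across policies; a minimal sketch, assuming `policy` is any policy implementing this API and `batch` comes from a validation dataloader:

    import torch

    policy.eval()
    with torch.no_grad():
        # Same forward path as training, but no optimizer step and no gradients.
        val_loss = policy.forward(batch)["loss"]
    print(f"validation loss: {val_loss.item():.3f}")
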
@@ -0,0 +1,45 @@
+"""A protocol that all policies should follow.
+
+This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
+subclass a base class.
+
+The protocol structure, method signatures, and docstrings should be used by developers as a reference for
+how to implement new policies.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from torch import Tensor
+
+
+@runtime_checkable
+class Policy(Protocol):
+    """The required interface for implementing a policy."""
+
+    name: str
+
+    def reset(self):
+        """To be called whenever the environment is reset.
+
+        Does things like clearing caches.
+        """
+
+    def forward(self, batch: dict[str, Tensor]) -> dict:
+        """Run the batch through the model and compute the loss for training or validation.
+
+        Returns a dictionary with "loss" and maybe other information.
+        """
+
+    def select_action(self, batch: dict[str, Tensor]):
+        """Return one action to run in the environment (potentially in batch mode).
+
+        When the model uses a history of observations, or outputs a sequence of actions, this method deals
+        with caching.
+        """
+
+    def update(self, batch):
+        """Does compute_loss then an optimization step.
+
+        TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
+        disappear.
+        """
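
Because `Policy` is decorated with `@runtime_checkable`, `isinstance` checks are structural: any class with the matching members passes, with no inheritance required. A toy sketch (the `NoopPolicy` below is hypothetical, not part of the PR):

    import torch
    from torch import Tensor

    from lerobot.common.policies.policy_protocol import Policy


    class NoopPolicy:
        """Hypothetical policy that satisfies the protocol without subclassing anything."""

        name = "noop"

        def reset(self):
            pass

        def forward(self, batch: dict[str, Tensor]) -> dict:
            return {"loss": torch.tensor(0.0)}

        def select_action(self, batch: dict[str, Tensor]):
            return torch.zeros(1)

        def update(self, batch):
            return {"loss": 0.0}


    # Structural check: passes because all protocol members are present.
    assert isinstance(NoopPolicy(), Policy)

Note that `runtime_checkable` only verifies member presence, not signatures, which is why the protocol's docstrings double as the reference documentation for implementers.
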
@@ -330,6 +330,10 @@ class TDMPCPolicy(nn.Module):
         return td_target

     def forward(self, batch, step):
+        # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
+        raise NotImplementedError()
+
+    def update(self, batch, step):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()

@@ -67,6 +67,4 @@ policy:
   utd: 1

 delta_timestamps:
-  observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
-  observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
   action: "[i / ${fps} for i in range(${policy.chunk_size})]"
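
The `delta_timestamps` entries are Hydra-interpolated strings that are resolved (by interpolation and then, as assumed here, evaluation of the resulting Python expression) into lists of time offsets in seconds relative to the current frame. A sketch of what the remaining `action` entry resolves to, with hypothetical values `fps: 50` and `policy.chunk_size: 100`:

    fps = 50
    chunk_size = 100
    # "[i / ${fps} for i in range(${policy.chunk_size})]" resolves to the next
    # `chunk_size` action timestamps for the dataset to fetch.
    delta_timestamps = [i / fps for i in range(chunk_size)]
    print(delta_timestamps[:4])  # [0.0, 0.02, 0.04, 0.06]
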
@@ -257,7 +257,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        train_info = policy(batch, step=step)
+        train_info = policy.update(batch, step=step)

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
@@ -318,7 +318,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        train_info = policy(batch, step)
+        train_info = policy.update(batch, step)

         if step % cfg.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
@@ -4,11 +4,13 @@ import torch
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.envs.factory import make_env
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.utils import init_hydra_config
 from .utils import DEVICE, DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
     [
@@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
     - Making the policy object.
+    - Checking that the policy follows the correct protocol.
    - Updating the policy.
    - Using the policy to select actions at inference time.
    - Test the action can be applied to the policy
@@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides):
             f"policy={policy_name}",
             f"device={DEVICE}",
         ]
-        + extra_overrides
+        + extra_overrides,
     )
     # Check that we can make the policy object.
     policy = make_policy(cfg)
+    # Check that the policy follows the required protocol.
+    assert isinstance(
+        policy, Policy
+    ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
     # Check that we run select_actions and get the appropriate output.
     dataset = make_dataset(cfg)
     env = make_env(cfg, num_parallel_envs=2)
@@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides):
            batch[key] = batch[key].to(DEVICE, non_blocking=True)

     # Test updating the policy
-    policy(batch, step=0)
+    policy.update(batch, step=0)

     # reset the policy and environment
     policy.reset()
@@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides):

     # Test step through policy
     env.step(action)
-