backup wip
This commit is contained in:
parent
0eb899de73
commit
cb3978b5f3
|
@ -165,8 +165,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
if self.cfg.n_action_steps is not None:
|
if self.cfg.n_action_steps is not None:
|
||||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||||
|
|
||||||
# def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||||
# return self.select_action(self, batch)
|
return self.select_action(self, batch)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||||
|
@ -186,20 +186,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Use the action chunking transformer to generate a sequence of actions."""
|
"""Use the action chunking transformer to generate a sequence of actions."""
|
||||||
self.eval()
|
self.eval()
|
||||||
self._preprocess_batch(batch, add_obs_steps_dim=True)
|
batch = self._reshape_batch(batch, add_obs_steps_dim=True)
|
||||||
action = self.forward(batch, return_loss=False)
|
actions, _ = self._forward(
|
||||||
return action[: self.cfg.n_action_steps]
|
batch["observation.state"], self.image_normalizer(batch["observation.images.top"])
|
||||||
|
)
|
||||||
|
return actions[: self.cfg.n_action_steps]
|
||||||
|
|
||||||
|
def _reshape_batch(self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False) -> dict[str, Tensor]:
|
||||||
|
"""Reshapes the batch items to account for various requirements of this policy.
|
||||||
|
|
||||||
def _preprocess_batch(
|
|
||||||
self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
"""
|
|
||||||
This function expects `batch` to have (at least):
|
This function expects `batch` to have (at least):
|
||||||
{
|
{
|
||||||
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
|
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
|
||||||
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
|
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TODO(alexander-soare): Right now this method does and undoes reshaping operations. This is just to
|
||||||
|
separate out the core logic from the temporary logic. See comments below.
|
||||||
"""
|
"""
|
||||||
|
# Create a shallow copy.
|
||||||
|
batch = dict(batch)
|
||||||
|
|
||||||
|
# Add a dimension for observation steps.
|
||||||
if add_obs_steps_dim:
|
if add_obs_steps_dim:
|
||||||
# Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
|
# Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
|
||||||
# this just amounts to an unsqueeze.
|
# this just amounts to an unsqueeze.
|
||||||
|
@ -207,18 +215,29 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
if k.startswith("observation."):
|
if k.startswith("observation."):
|
||||||
batch[k] = batch[k].unsqueeze(1)
|
batch[k] = batch[k].unsqueeze(1)
|
||||||
|
|
||||||
if batch["observation.state"].shape[1] != 1:
|
# Temporary logic to remove the observation step dimension as the policy does not yet handle it.
|
||||||
|
# TODO(alexander-soare): generalize this to multiple observations steps.
|
||||||
|
# Check that there is only 1 observation step (policy does not yet handle more).
|
||||||
|
if not all(batch[k].shape[1] == 1 for k in batch if k.startswith("observation.")):
|
||||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
||||||
batch["observation.state"] = batch["observation.state"].squeeze(1)
|
# Remove observation steps dimension.
|
||||||
# TODO(alexander-soare): generalize this to multiple images.
|
for k in batch:
|
||||||
|
if k.startswith("observation."):
|
||||||
|
batch[k] = batch[k].squeeze(1)
|
||||||
|
|
||||||
|
# Temporary logic to add the multiple image dimension back in.
|
||||||
|
# TODO(alexander-soare): generalize this to multiple images. Once resolved, this logic will stack all
|
||||||
|
# images.
|
||||||
assert (
|
assert (
|
||||||
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
|
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
|
||||||
), "ACT only handles one image for now."
|
), f"{self.__class__.__name__} only handles one image for now."
|
||||||
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
|
# Since we only handle one image, just unsqueeze instead of stacking.
|
||||||
# the image index dimension.
|
batch["observation.images.top"] = batch["observation.images.top"].unsqueeze(1)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
def compute_loss(self, batch, **_) -> float:
|
def compute_loss(self, batch, **_) -> float:
|
||||||
self._preprocess_batch(batch)
|
batch = self._reshape_batch(batch)
|
||||||
|
|
||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
|
@ -228,7 +247,28 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
assert batch_size % self.cfg.chunk_size == 0
|
assert batch_size % self.cfg.chunk_size == 0
|
||||||
assert batch_size % num_slices == 0
|
assert batch_size % num_slices == 0
|
||||||
|
|
||||||
loss = self.forward(batch, return_loss=True)["loss"]
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
|
||||||
|
batch["observation.state"],
|
||||||
|
self.image_normalizer(batch["observation.images.top"]),
|
||||||
|
batch["action"],
|
||||||
|
)
|
||||||
|
|
||||||
|
l1_loss = (
|
||||||
|
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
|
).mean()
|
||||||
|
|
||||||
|
if self.cfg.use_vae:
|
||||||
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
|
# KL-divergence per batch element, then take the mean over the batch.
|
||||||
|
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||||
|
mean_kld = (
|
||||||
|
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||||
|
)
|
||||||
|
loss = l1_loss + mean_kld * self.cfg.kl_weight
|
||||||
|
else:
|
||||||
|
loss = l1_loss
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def update(self, batch, **_) -> dict:
|
def update(self, batch, **_) -> dict:
|
||||||
|
@ -255,39 +295,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
|
|
||||||
"""A forward pass through the DNN part of this policy with optional loss computation."""
|
|
||||||
images = self.image_normalizer(batch["observation.images.top"])
|
|
||||||
|
|
||||||
if return_loss: # training time
|
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
|
|
||||||
batch["observation.state"], images, batch["action"]
|
|
||||||
)
|
|
||||||
|
|
||||||
l1_loss = (
|
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none")
|
|
||||||
* ~batch["action_is_pad"].unsqueeze(-1)
|
|
||||||
).mean()
|
|
||||||
|
|
||||||
loss_dict = {}
|
|
||||||
loss_dict["l1"] = l1_loss
|
|
||||||
if self.cfg.use_vae:
|
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
|
||||||
# KL-divergence per batch element, then take the mean over the batch.
|
|
||||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
|
||||||
mean_kld = (
|
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
|
||||||
)
|
|
||||||
loss_dict["kl"] = mean_kld
|
|
||||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
|
|
||||||
else:
|
|
||||||
loss_dict["loss"] = loss_dict["l1"]
|
|
||||||
return loss_dict
|
|
||||||
else:
|
|
||||||
action, _ = self._forward(batch["observation.state"], images)
|
|
||||||
return action
|
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
|
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
|
||||||
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
|
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
"""A protocol that all policies should follow.
|
||||||
|
|
||||||
|
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
||||||
|
subclass a base class.
|
||||||
|
|
||||||
|
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
|
||||||
|
how to implement new policies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Policy(Protocol):
|
||||||
|
"""The required interface for implementing a policy."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""To be called whenever the environment is reset.
|
||||||
|
|
||||||
|
Does things like clearing caches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor], **kwargs):
|
||||||
|
"""Wired to `select_action`."""
|
||||||
|
|
||||||
|
def select_action(self, batch: dict[str, Tensor], **kwargs):
|
||||||
|
"""Return one action to run in the environment (potentially in batch mode).
|
||||||
|
|
||||||
|
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||||
|
with caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_loss(self, batch: dict[str, Tensor], **kwargs):
|
||||||
|
"""Runs the batch through the model and computes the loss for training or validation."""
|
||||||
|
|
||||||
|
def update(self, batch, **kwargs):
|
||||||
|
"""Does compute_loss then an optimization step.
|
||||||
|
|
||||||
|
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
|
||||||
|
disappear.
|
||||||
|
"""
|
Loading…
Reference in New Issue