733 lines
28 KiB
Python
733 lines
28 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""
|
||
π0: A Vision-Language-Action Flow Model for General Robot Control
|
||
|
||
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
|
||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||
|
||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||
|
||
Install pi0 extra dependencies:
|
||
```bash
|
||
pip install -e ".[pi0]"
|
||
```
|
||
|
||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||
```bash
|
||
python lerobot/scripts/train.py \
|
||
--policy.path=lerobot/pi0 \
|
||
--dataset.repo_id=danaaubakirova/koch_test
|
||
```
|
||
|
||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||
pretrained with VLM default parameters before pi0 finetuning:
|
||
```bash
|
||
python lerobot/scripts/train.py \
|
||
--policy.type=pi0 \
|
||
--dataset.repo_id=danaaubakirova/koch_test
|
||
```
|
||
|
||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||
```python
|
||
policy = PI0Policy.from_pretrained("lerobot/pi0")
|
||
```
|
||
|
||
"""
|
||
|
||
import math
|
||
from collections import deque
|
||
|
||
import torch
|
||
import torch.nn.functional as F # noqa: N812
|
||
from torch import Tensor, nn
|
||
from transformers import AutoTokenizer
|
||
|
||
from lerobot.common.constants import ACTION, OBS_STATE
|
||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||
PaliGemmaWithExpertConfig,
|
||
PaliGemmaWithExpertModel,
|
||
)
|
||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||
from lerobot.common.utils.utils import get_safe_dtype
|
||
|
||
|
||
def create_sinusoidal_pos_embedding(
    time: Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Compute sine-cosine positional embedding vectors for scalar positions.

    The periods of the sinusoids are spaced geometrically between `min_period` and `max_period`.

    Args:
        time: 1-D tensor of shape `(batch_size,)` holding the scalar positions to embed.
        dimension: Size of the returned embedding; must be even.
        min_period: Shortest period used by the sinusoids.
        max_period: Longest period used by the sinusoids.
        device: Device on which the frequencies are created, given either as a `torch.device` or
            as a device string such as `"cpu"`.

    Returns:
        Tensor of shape `(batch_size, dimension)` whose first half (last axis) holds the sines and
        whose second half holds the cosines.

    Raises:
        ValueError: If `dimension` is odd or `time` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # Strings (including the default "cpu") have no `.type` attribute, so normalize first.
    device = torch.device(device)

    # Ask for float64; presumably `get_safe_dtype` downgrades it on device types that lack
    # float64 support -- confirm against its implementation.
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Outer product between the angular frequencies and the positions: (batch_size, dimension // 2)
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||
|
||
|
||
def sample_beta(alpha, beta, bsize, device):
    """Draw `bsize` independent samples from a Beta(alpha, beta) distribution.

    The previous implementation returned `U**(1/alpha) / (U**(1/alpha) + V**(1/beta))` for uniform
    `U`, `V` without the rejection step that construction requires, so the result was not
    Beta-distributed (e.g. its mean for alpha=1.5, beta=1 was about 0.558 instead of 0.6).

    Args:
        alpha: First concentration parameter (> 0).
        beta: Second concentration parameter (> 0).
        bsize: Number of samples to draw.
        device: Device on which the samples are created.

    Returns:
        float32 tensor of shape `(bsize,)` with values in [0, 1].
    """
    concentration1 = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    concentration0 = torch.as_tensor(beta, dtype=torch.float32, device=device)
    return torch.distributions.Beta(concentration1, concentration0).sample((bsize,))
|
||
|
||
|
||
def make_att_2d_masks(pad_masks, att_masks):
    """Build per-sample 2D attention masks (adapted from big_vision).

    A token may attend to every valid (non-padding) token whose cumulative `att_masks` value is
    smaller than or equal to its own. This lets `att_masks` describe several attention patterns,
    for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens attend to each other and the last
          3 tokens are causal. The first entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a block attend to all
          previous blocks and to every token of their own block.

    Args:
        pad_masks: `[B, N]` mask, true where the token is real input and false where it is padding.
        att_masks: `[B, N]` mask, 1 where previous tokens must not attend to this token and 0 where
            the token shares the attention pattern of the previous token.

    Returns:
        `[B, N, N]` mask where entry `[b, i, j]` tells whether token `i` may attend to token `j`.

    Raises:
        ValueError: If either mask is not 2-dimensional (the offending `ndim` is the message).
    """
    for mask in (att_masks, pad_masks):
        if mask.ndim != 2:
            raise ValueError(mask.ndim)

    # Tokens with the same running sum belong to the same attention block.
    block_ids = att_masks.cumsum(dim=1)
    may_attend = block_ids.unsqueeze(1) <= block_ids.unsqueeze(2)

    # Both the query and the key have to be real (non-padding) tokens.
    both_valid = pad_masks.unsqueeze(1) * pad_masks.unsqueeze(2)
    return may_attend & both_valid
|
||
|
||
|
||
def resize_with_pad(img, width, height, pad_value=-1):
    """Resize a batch of images to fit in `width` x `height`, keeping the aspect ratio.

    The images are rescaled with bilinear interpolation so that they fit inside the target box,
    then the remaining space is filled with `pad_value` on the left and on the top.

    Args:
        img: Image batch of shape `(b, c, h, w)`.
        width: Target width.
        height: Target height.
        pad_value: Value written into the padded area.

    Returns:
        The resized and padded image batch.

    Raises:
        ValueError: If `img` is not 4-dimensional.
    """
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    src_height, src_width = img.shape[2:]

    # The larger of the two ratios is the one that makes both sides fit in the target box.
    scale = max(src_width / width, src_height / height)
    new_height, new_width = int(src_height / scale), int(src_width / scale)
    scaled = F.interpolate(img, size=(new_height, new_width), mode="bilinear", align_corners=False)

    top = max(0, int(height - new_height))
    left = max(0, int(width - new_width))

    # F.pad takes (left, right, top, bottom) for the last two axes: pad on the left and top only.
    return F.pad(scaled, (left, 0, top, 0), value=pad_value)
|
||
|
||
|
||
def pad_vector(vector, new_dim):
    """Zero-pad the last axis of `vector` up to `new_dim` entries.

    `vector` can be (batch_size x sequence_length x features_dimension)
    or (batch_size x features_dimension).

    Args:
        vector: Tensor whose last axis is the feature axis.
        new_dim: Requested size of the feature axis.

    Returns:
        `vector` itself when its last axis already has `new_dim` entries; otherwise a new tensor
        with the same dtype and device holding `vector` followed by zeros along the last axis.
    """
    current_dim = vector.shape[-1]
    if current_dim == new_dim:
        return vector

    padded = vector.new_zeros(*vector.shape[:-1], new_dim)
    padded[..., :current_dim] = vector
    return padded
|
||
|
||
|
||
def normalize(x, min_val, max_val):
    """Map `x` linearly so that `min_val` becomes 0 and `max_val` becomes 1."""
    offset = x - min_val
    span = max_val - min_val
    return offset / span
|
||
|
||
|
||
def unnormalize(x, min_val, max_val):
    """Map `x` linearly so that 0 becomes `min_val` and 1 becomes `max_val`."""
    span = max_val - min_val
    return x * span + min_val
|
||
|
||
|
||
def safe_arcsin(value):
    """Return the arcsine of `value` after clamping it to the valid domain [-1, 1].

    Clamping first means out-of-range inputs are treated as -1 or 1 instead of being passed to
    arcsin unchanged.
    """
    bounded = torch.clamp(value, min=-1.0, max=1.0)
    return torch.asin(bounded)
|
||
|
||
|
||
def aloha_gripper_to_angular(value):
    """Convert an Aloha gripper position from linear space to pi0's normalized angular space.

    Aloha transforms the gripper positions into a linear space. This function reverses that
    transformation to be consistent with pi0, which is pretrained in angular space.
    """
    # Constants taken from the Interbotix code.
    arm_length = 0.036
    horn_radius = 0.022

    # Undo Aloha's normalization. The bounds come from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    linear_position = unnormalize(value, min_val=0.01844, max_val=0.05800)

    # Inverse of the angular to linear transformation inside the Interbotix code.
    sine = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
    angle = safe_arcsin(sine)

    # Normalize to [0, 1].
    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    return normalize(angle, min_val=0.4, max_val=1.5)
|
||
|
||
|
||
def aloha_gripper_from_angular(value):
    """Convert a gripper position from pi0's convention to the one used by Aloha.

    The units stay angular; only the range changes. The value is first mapped from [0, 1] back to
    the [0.4, 1.5] range measured on an actual Trossen robot, then normalized with the bounds
    coming from the Aloha code (PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE).
    """
    return normalize(
        unnormalize(value, min_val=0.4, max_val=1.5),
        min_val=-0.6213,
        max_val=1.4910,
    )
|
||
|
||
|
||
def aloha_gripper_from_angular_inv(value):
    """Invert `aloha_gripper_from_angular`: go from Aloha's gripper range back to pi0's.

    The two linear maps of `aloha_gripper_from_angular` are undone in reverse order, using the same
    bounds.
    """
    return normalize(
        unnormalize(value, min_val=-0.6213, max_val=1.4910),
        min_val=0.4,
        max_val=1.5,
    )
|
||
|
||
|
||
class PI0Policy(PreTrainedPolicy):
    """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot.

    The wrapper owns input/output normalization, prompt tokenization, the preprocessing of images,
    state and actions, and the queue of pending actions. The network itself lives in `self.model`.
    """

    config_class = PI0Config
    name = "pi0"

    def __init__(
        self,
        config: PI0Config,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance. Its features are validated here.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FlowMatching(config)

        self.reset()

    def reset(self):
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def get_optim_params(self) -> dict:
        """Return the parameters handed to the optimizer (every parameter of the policy)."""
        return self.parameters()

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `self.model.sample_actions` in order to return one action at a time for
        execution in the environment. It works by managing the actions in a queue and only sampling a
        new chunk of actions when the queue is empty.

        Args:
            batch: Observations; must contain the state, the task prompt(s) and at least one of the
                configured image features. The caller's dict and tensors are left untouched.
            noise: Optional initial noise forwarded to `self.model.sample_actions`.

        Returns:
            The next action of the queue, one row per batch element.
        """
        self.eval()

        if self.config.adapt_to_pi_aloha:
            # Shallow copy so that the caller's dict keeps pointing at the raw observation.
            batch = dict(batch)
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])

        batch = self.normalize_inputs(batch)

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            images, img_masks = self.prepare_images(batch)
            state = self.prepare_state(batch)
            lang_tokens, lang_masks = self.prepare_language(batch)

            actions = self.model.sample_actions(
                images, img_masks, lang_tokens, lang_masks, state, noise=noise
            )

            # Unpad actions
            original_action_dim = self.config.action_feature.shape[0]
            actions = actions[:, :, :original_action_dim]

            actions = self.unnormalize_outputs({"action": actions})["action"]

            if self.config.adapt_to_pi_aloha:
                actions = self._pi_aloha_encode_actions(actions)

            # `self.model.sample_actions` returns a (batch_size, n_steps, action_dim) tensor, but the
            # queue effectively has shape (n_steps, batch_size, *), hence the transpose.
            # A bounded deque discards entries from the left when over-filled, which would throw away
            # the earliest actions of a longer chunk, so keep only the first `n_action_steps`.
            self._action_queue.extend(actions.transpose(0, 1)[: self.config.n_action_steps])
        return self._action_queue.popleft()

    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
        """Do a full training forward pass to compute the loss.

        Args:
            batch: Training batch with state, actions, task prompt(s), image features and optionally
                `"action_is_pad"`. The caller's dict and tensors are left untouched.
            noise: Optional noise forwarded to the flow matching model.
            time: Optional flow matching time forwarded to the flow matching model.

        Returns:
            The scalar loss used for the backward pass and a dict with intermediate losses for logging.
        """
        if self.config.adapt_to_pi_aloha:
            # Shallow copy so that the caller's dict keeps pointing at the raw tensors.
            batch = dict(batch)
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])

        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)

        images, img_masks = self.prepare_images(batch)
        state = self.prepare_state(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)
        actions = self.prepare_action(batch)
        actions_is_pad = batch.get("action_is_pad")

        loss_dict = {}
        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
        loss_dict["losses_after_forward"] = losses.clone()

        if actions_is_pad is not None:
            # Zero the loss of the steps that fall outside the episode.
            in_episode_bound = ~actions_is_pad
            losses = losses * in_episode_bound.unsqueeze(-1)
            loss_dict["losses_after_in_ep_bound"] = losses.clone()

        # Remove padding
        # NOTE(review): the actions were padded to `max_action_dim` in `prepare_action`, so this slice
        # looks like it keeps every dimension, including the padded ones -- confirm this is intended.
        losses = losses[:, :, : self.config.max_action_dim]
        loss_dict["losses_after_rm_padding"] = losses.clone()

        # For backward pass
        loss = losses.mean()
        # For logging
        loss_dict["l2_loss"] = loss.item()

        return loss, loss_dict

    def prepare_images(self, batch):
        """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
        convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.

        Returns:
            A list of image batches and a list of boolean masks (one entry per camera). The mask is all
            True for cameras present in `batch` and all False for the placeholder images added for
            missing cameras.

        Raises:
            ValueError: If none of the configured image features is present in `batch`.
        """
        images = []
        img_masks = []

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
            )

        # Preprocess image features present in the batch
        for key in present_img_keys:
            img = batch[key]

            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

            # Normalize from range [0,1] to [-1,1] as expected by siglip
            img = img * 2.0 - 1.0

            bsize = img.shape[0]
            device = img.device
            mask = torch.ones(bsize, dtype=torch.bool, device=device)
            images.append(img)
            img_masks.append(mask)

        # Create placeholders for image features not present in the batch (at most
        # `config.empty_cameras` of them): images filled with -1, the lower bound of the normalized
        # pixel range, shaped like the last real image, together with an all-False mask.
        for _ in range(min(len(missing_img_keys), self.config.empty_cameras)):
            images.append(torch.ones_like(img) * -1)
            img_masks.append(torch.zeros_like(mask))

        return images, img_masks

    def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
        """Tokenize the text input.

        Returns:
            The token ids and the boolean attention mask, both padded to
            `config.tokenizer_max_length` and moved to the device of the state.
        """
        device = batch[OBS_STATE].device
        tasks = batch["task"]

        # A bare string would otherwise be iterated character by character below.
        if isinstance(tasks, str):
            tasks = [tasks]

        # PaliGemma prompt has to end with a new line
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        tokenized_prompt = self.language_tokenizer(
            tasks,
            padding="max_length",
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks

    def _pi_aloha_decode_state(self, state):
        """Return a copy of `state` converted from the Aloha runtime convention to the pi0 one."""
        # Work on a copy: the tensor belongs to the caller.
        state = state.clone()
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            state[:, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
        return state

    def _pi_aloha_encode_actions(self, actions):
        """Return a copy of `actions` converted from the pi0 convention to the Aloha runtime one."""
        # Work on a copy: the tensor may be shared with the caller.
        actions = actions.clone()
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
        return actions

    def _pi_aloha_encode_actions_inv(self, actions):
        """Return a copy of `actions` with `_pi_aloha_encode_actions` undone (Aloha -> pi0)."""
        # Work on a copy: the tensor belongs to the caller (e.g. the training batch).
        actions = actions.clone()
        # Flip the joints again.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
        return actions

    def prepare_state(self, batch):
        """Pad the state with zeros up to `config.max_state_dim` features."""
        return pad_vector(batch[OBS_STATE], self.config.max_state_dim)

    def prepare_action(self, batch):
        """Pad the actions with zeros up to `config.max_action_dim` features."""
        return pad_vector(batch[ACTION], self.config.max_action_dim)
|
||
|
||
|
||
class PI0FlowMatching(nn.Module):
    """
    π0: A Vision-Language-Action Flow Model for General Robot Control

    [Paper](https://www.physicalintelligence.company/download/pi0.pdf)
    [Jax code](https://github.com/Physical-Intelligence/openpi)

    Designed by Physical Intelligence. Ported from Jax by Hugging Face.
    ┌──────────────────────────────┐
    │               actions        │
    │               ▲              │
    │              ┌┴─────┐        │
    │  kv cache    │Gemma │        │
    │  ┌──────────►│Expert│        │
    │  │           │      │        │
    │ ┌┴────────┐  │x 10  │        │
    │ │         │  └▲──▲──┘        │
    │ │PaliGemma│   │  │           │
    │ │         │   │  robot state │
    │ │         │   noise          │
    │ └▲──▲─────┘                  │
    │  │  │                        │
    │  │  image(s)                 │
    │  language tokens             │
    └──────────────────────────────┘
    """

    def __init__(self, config):
        """Build the PaliGemma + expert backbone and the state/action/time projections."""
        super().__init__()
        self.config = config

        paligemma_with_expert_config = PaliGemmaWithExpertConfig(
            freeze_vision_encoder=self.config.freeze_vision_encoder,
            train_expert_only=self.config.train_expert_only,
            attention_implementation=self.config.attention_implementation,
        )
        self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)

        # Projections are float32
        self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
        self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
        self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)

        self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
        self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)

        self.set_requires_grad()

    def set_requires_grad(self):
        """Enable or disable training of the state projection according to `config.train_state_proj`."""
        for params in self.state_proj.parameters():
            params.requires_grad = self.config.train_state_proj

    def sample_noise(self, shape, device):
        """Return standard normal float32 noise of the given shape on `device`."""
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )

    def sample_time(self, bsize, device):
        """Sample `bsize` flow matching times as float32.

        The samples of `sample_beta(1.5, 1.0, ...)` are rescaled with `t * 0.999 + 0.001`, which keeps
        them at or above 0.001 for samples in [0, 1].
        """
        time_beta = sample_beta(1.5, 1.0, bsize, device)
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)

    def embed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer to prepare
        for PaliGemma transformer processing.

        Returns:
            The concatenated embeddings, the matching padding masks and the attention masks (all
            zeros, i.e. image and language tokens attend to each other).
        """
        # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
        embs = []
        pad_masks = []
        att_masks = []

        # TODO: remove for loop
        for (
            img,
            img_mask,
        ) in zip(images, img_masks, strict=False):
            img_emb = self.paligemma_with_expert.embed_image(img)
            img_emb = img_emb.to(dtype=torch.bfloat16)

            # Normalize image embeddings
            img_emb_dim = img_emb.shape[-1]
            img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)

            bsize, num_img_embs = img_emb.shape[:2]
            # One mask entry per image, broadcast to every token of that image.
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)

            embs.append(img_emb)
            pad_masks.append(img_mask)

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)

        # Normalize language embeddings
        lang_emb_dim = lang_emb.shape[-1]
        lang_emb = lang_emb * math.sqrt(lang_emb_dim)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        # Take the batch size from the concatenated masks rather than from the image loop variable,
        # which is undefined when no image is given.
        att_masks = att_masks[None, :].expand(pad_masks.shape[0], len(att_masks))

        return embs, pad_masks, att_masks

    def embed_suffix(self, state, noisy_actions, timestep):
        """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.

        Returns:
            The concatenated embeddings (one state token followed by one token per action step), the
            matching padding masks (all True) and the attention masks.
        """
        embs = []
        pad_masks = []
        att_masks = []

        # Embed state
        state_emb = self.state_proj(state)
        state_emb = state_emb.to(dtype=torch.bfloat16)
        embs.append(state_emb[:, None, :])
        bsize = state_emb.shape[0]
        dtype = state_emb.dtype
        device = state_emb.device

        state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
        pad_masks.append(state_mask)

        # Set attention masks so that image and language inputs do not attend to state or actions
        att_masks += [1]

        # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = create_sinusoidal_pos_embedding(
            timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
        )
        time_emb = time_emb.type(dtype=dtype)

        # Fuse timestep + action information using an MLP
        action_emb = self.action_in_proj(noisy_actions)

        time_emb = time_emb[:, None, :].expand_as(action_emb)
        action_time_emb = torch.cat([action_emb, time_emb], dim=2)

        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb)  # swish == silu
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        # Add to input tokens
        embs.append(action_time_emb)

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens.
        # The length follows the actual number of action tokens so that it always matches the padding
        # masks, even when the horizon of `noisy_actions` differs from `config.n_action_steps`.
        att_masks += [1] + ([0] * (action_time_dim - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def forward(
        self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
    ) -> Tensor:
        """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors).

        Args:
            images, img_masks, lang_tokens, lang_masks: Prefix inputs, see `embed_prefix`.
            state: Padded robot state.
            actions: Padded target actions.
            noise: Optional noise with the shape of `actions`; sampled when None.
            time: Optional flow matching time, one value per batch element; sampled when None.

        Returns:
            The unreduced squared error between the target and the predicted velocity.
        """
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        # Interpolate between the actions (time 0) and the noise (time 1); the regression target is
        # the velocity of that path.
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        (_, suffix_out), _ = self.paligemma_with_expert.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, suffix_embs],
            use_cache=False,
            fill_kv_cache=False,
        )
        # Keep only the action tokens (the suffix starts with the state token). The count comes from
        # `actions` so that the prediction always lines up with the target `u_t`.
        suffix_out = suffix_out[:, -actions.shape[1] :]
        # Original openpi code, upcast attention output
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.action_out_proj(suffix_out)

        losses = F.mse_loss(u_t, v_t, reduction="none")
        return losses

    def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
        """Do a full inference forward and compute the action (batch_size x num_steps x num_motors).

        Starting from noise at time 1, the predicted velocity is integrated down to time 0 with
        Euler steps of size `1 / config.num_steps`.

        Args:
            images, img_masks, lang_tokens, lang_masks: Prefix inputs, see `embed_prefix`.
            state: Padded robot state.
            noise: Optional initial noise; sampled with shape
                (batch_size, config.n_action_steps, config.max_action_dim) when None. It is not
                modified.

        Returns:
            The denoised (still padded and normalized) actions.
        """
        bsize = state.shape[0]
        device = state.device

        if noise is None:
            actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
            noise = self.sample_noise(actions_shape, device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        # Compute image and language key value cache
        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=self.config.use_cache,
            fill_kv_cache=True,
        )

        dt = -1.0 / self.config.num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        # Stop once `time` has dropped below half a step, i.e. when it has reached ~0 up to
        # floating point accumulation error.
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                state,
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )

            # Euler step. Out-of-place on purpose: `x_t` starts as the caller's `noise` tensor, which
            # an in-place update would overwrite.
            x_t = x_t + dt * v_t
            time += dt
        return x_t

    def denoise_step(
        self,
        state,
        prefix_pad_masks,
        past_key_values,
        x_t,
        timestep,
    ):
        """Apply one denoising step of the noise `x_t` at a given timestep.

        Args:
            state: Padded robot state.
            prefix_pad_masks: Padding masks of the prefix whose keys/values are cached.
            past_key_values: Cache returned by the prefix forward pass.
            x_t: Current noisy actions.
            timestep: Current flow matching time, one value per batch element.

        Returns:
            The predicted velocity `v_t`, one vector per action step of `x_t`.
        """
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)

        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]
        # Every suffix token may attend to every non-padding prefix token.
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)

        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

        # Suffix positions continue after the last valid prefix position.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=self.config.use_cache,
            fill_kv_cache=False,
        )
        suffix_out = outputs_embeds[1]
        # Keep only the action tokens; the count follows `x_t` so that the velocity has its shape.
        suffix_out = suffix_out[:, -x_t.shape[1] :]
        suffix_out = suffix_out.to(dtype=torch.float32)
        v_t = self.action_out_proj(suffix_out)
        return v_t
|