lerobot/lerobot/common/policies/pi0/modeling_pi0.py

733 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Install pi0 extra dependencies:
```bash
pip install -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
pretrained with VLM default parameters before pi0 finetuning:
```bash
python lerobot/scripts/train.py \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = Pi0Policy.from_pretrained("lerobot/pi0")
```
"""
import math
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.common.constants import ACTION, OBS_STATE
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
PaliGemmaWithExpertModel,
)
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def resize_with_pad(img, width, height, pad_value=-1):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0Policy(PreTrainedPolicy):
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
config_class = PI0Config
name = "pi0"
def __init__(
self,
config: PI0Config,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
return loss, loss_dict
def prepare_images(self, batch):
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
"""
images = []
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
images.append(img)
img_masks.append(mask)
# Create image features not present in the batch
# as fully 0 padded images.
for num_empty_cameras in range(len(missing_img_keys)):
if num_empty_cameras >= self.config.empty_cameras:
break
img = torch.ones_like(img) * -1
mask = torch.zeros_like(mask)
images.append(img)
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
class PI0FlowMatching(nn.Module):
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌┴─────┐ │
│ kv cache │Gemma │ │
│ ┌──────────►│Expert│ │
│ │ │ │ │
│ ┌┴────────┐ │x 10 │ │
│ │ │ └▲──▲──┘ │
│ │PaliGemma│ │ │ │
│ │ │ │ robot state │
│ │ │ noise │
│ └▲──▲─────┘ │
│ │ │ │
│ │ image(s) │
│ language tokens │
└──────────────────────────────┘
"""
def __init__(self, config):
super().__init__()
self.config = config
paligemma_with_export_config = PaliGemmaWithExpertConfig(
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
# Projections are float32
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
self.set_requires_grad()
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
embs = []
pad_masks = []
att_masks = []
# TODO: remove for loop
for (
img,
img_mask,
) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
img_emb = img_emb.to(dtype=torch.bfloat16)
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
pad_masks.append(img_mask)
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Embed state
state_emb = self.state_proj(state)
state_emb = state_emb.to(dtype=torch.bfloat16)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
dtype = state_emb.dtype
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
)
time_emb = time_emb.type(dtype=dtype)
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish == silu
action_time_emb = self.action_time_mlp_out(action_time_emb)
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
suffix_out = suffix_out[:, -self.config.n_action_steps :]
# Original openpi code, upcast attention output
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step
x_t += dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.n_action_steps :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t