clean and refactor pi0fast

This commit is contained in:
mshukor 2025-03-31 16:28:39 +02:00
parent 7a45fa0fc1
commit 6e48e044d7
3 changed files with 37 additions and 131 deletions

View File

@ -72,7 +72,7 @@
0.95 0.95
], ],
"optimizer_eps": 1e-08, "optimizer_eps": 1e-08,
"optimizer_weight_decay": 1e-10, "optimizer_weight_decay": 1e-5,
"scheduler_warmup_steps": 1000, "scheduler_warmup_steps": 1000,
"scheduler_decay_steps": 30000, "scheduler_decay_steps": 30000,
"scheduler_decay_lr": 2.5e-06 "scheduler_decay_lr": 2.5e-06

View File

@ -8,13 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@dataclass
class PEFTConfig:
r: int = 4
lora_alpha: int = 16
lora_dropout: float = 0.1
target_modules: str = "q_proj,v_proj"
@PreTrainedConfig.register_subclass("pi0fast") @PreTrainedConfig.register_subclass("pi0fast")
@dataclass @dataclass
@ -71,10 +64,10 @@ class PI0FASTConfig(PreTrainedConfig):
freeze_lm_head: bool = True freeze_lm_head: bool = True
# Training presets # Training presets
optimizer_lr: float = 2.5e-5 optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95) optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8 optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10 optimizer_weight_decay: float = 1e-5
scheduler_warmup_steps: int = 1_000 scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000 scheduler_decay_steps: int = 30_000
@ -85,15 +78,8 @@ class PI0FASTConfig(PreTrainedConfig):
padding_side: str = "right" padding_side: str = "right"
# peft_method: str = ""
# peft_config: PEFTConfig = PEFTConfig()
precision: str = "bfloat16" precision: str = "bfloat16"
attention_mode: str = "prefix" grad_clip_norm: float = 1
action_kw_to_prefix: bool = True
# TODO: Add EMA
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
@ -110,10 +96,6 @@ class PI0FASTConfig(PreTrainedConfig):
) )
def validate_features(self) -> None: def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
for i in range(self.empty_cameras): for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}" key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature( empty_camera = PolicyFeature(
@ -128,6 +110,7 @@ class PI0FASTConfig(PreTrainedConfig):
betas=self.optimizer_betas, betas=self.optimizer_betas,
eps=self.optimizer_eps, eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay, weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.grad_clip_norm,
) )
def get_scheduler_preset(self): def get_scheduler_preset(self):

View File

@ -50,7 +50,6 @@ from functools import partial
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
# from peft import LoraConfig, TaskType, get_peft_model
from PIL import Image from PIL import Image
from scipy.fft import idct from scipy.fft import idct
from torch import Tensor, nn from torch import Tensor, nn
@ -58,7 +57,6 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
from transformers.cache_utils import HybridCache, StaticCache from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING from transformers.models.auto import CONFIG_MAPPING
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
@ -175,7 +173,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config) self.model = PI0FAST(config)
self.adapt_to_pi_aloha = True #self.config.adapt_to_pi_aloha # FIXME(mshukor): debug
self.reset() self.reset()
@ -223,7 +220,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
""" """
self.eval() self.eval()
if self.adapt_to_pi_aloha: if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
@ -242,7 +239,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
if self.adapt_to_pi_aloha: if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions) actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@ -251,7 +248,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
return self._action_queue.popleft() return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.adapt_to_pi_aloha: if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
@ -275,8 +272,6 @@ def block_causal_update_causal_mask(
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask
return None return None
# dtype = self.pi0_paligemma.dtype
# is_training = is_training if is_training is not None else self.training
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
@ -398,7 +393,7 @@ def prepare_inputs_for_generation(
**kwargs, **kwargs,
) )
# position_ids in Paligemma are 1-indexed # Position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None: if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1 model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
@ -491,19 +486,17 @@ class PI0FAST(nn.Module):
self.pi0_paligemma.prepare_inputs_for_generation = partial( self.pi0_paligemma.prepare_inputs_for_generation = partial(
prepare_inputs_for_generation, self=self.pi0_paligemma prepare_inputs_for_generation, self=self.pi0_paligemma
) )
# self.pi0_paligemma = self.configure_peft(pi0_paligemma)
# change important stuff in bf16 # change important stuff in bf16
params_to_change_dtype = [ params_to_change_dtype = [
"language_model", "language_model",
"vision_tower", "vision_tower",
"multi_modal", "multi_modal",
] ]
print(f"Cast model params to {precision}")
for name, param in self.pi0_paligemma.named_parameters(): for name, param in self.pi0_paligemma.named_parameters():
if any(selector in name for selector in params_to_change_dtype): if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch_precision) param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad() self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
self.ignore_index = self.pi0_paligemma.config.ignore_index self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side self.padding_side = self.config.padding_side
@ -518,32 +511,6 @@ class PI0FAST(nn.Module):
if any([k in name for k in ["embed_tokens"]]): # lm heads and embedding layer are tied if any([k in name for k in ["embed_tokens"]]): # lm heads and embedding layer are tied
params.requires_grad = False params.requires_grad = False
# def configure_peft(self, model):
# self.peft_method = self.config.peft_method
# if "lora" in self.peft_method:
# peft_config = self.config.peft_config
# target_modules = peft_config.target_modules
# if not isinstance(target_modules, list):
# target_modules = target_modules.split(",")
# lora_config = LoraConfig(
# task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
# r=peft_config.r, # The rank of the low-rank adaptation
# lora_alpha=peft_config.lora_alpha, # Scaling factor
# lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
# target_modules=target_modules, # The components where LoRA is applied
# )
# self.lora_config = lora_config
# model = get_peft_model(model, lora_config)
# for name, param in model.named_parameters():
# if (
# "lora" in name
# ): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
# param.requires_grad = True
# else:
# param.requires_grad = False
# return model
def embed_tokens(self, tokens: torch.Tensor): def embed_tokens(self, tokens: torch.Tensor):
return self.pi0_paligemma.language_model.model.embed_tokens(tokens) return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
@ -554,11 +521,7 @@ class PI0FAST(nn.Module):
"""Preprocess LeRobot batch into Pi0 inputs""" """Preprocess LeRobot batch into Pi0 inputs"""
images = [] images = []
img_masks = [] img_masks = []
img_keys = sorted(self.config.image_features.keys(), key=lambda k: IMAGES_ORDER.get(k, float("inf"))) present_img_keys = [key for key in self.image_keys if key in batch]
present_img_keys = [key for key in self.config.image_features if key in batch]
# missing_img_keys = [key for key in self.config.image_features if key not in batch]
# present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")))
if len(present_img_keys) == 0: if len(present_img_keys) == 0:
raise ValueError( raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
@ -566,7 +529,7 @@ class PI0FAST(nn.Module):
# Preprocess image features present in the batch # Preprocess image features present in the batch
num_empty_cameras = 0 num_empty_cameras = 0
for key in img_keys: for key in self.image_keys:
if key in present_img_keys: if key in present_img_keys:
img = batch[key] img = batch[key]
@ -586,14 +549,11 @@ class PI0FAST(nn.Module):
mask = torch.ones(bsize, dtype=torch.bool, device=device) mask = torch.ones(bsize, dtype=torch.bool, device=device)
else: else:
if num_empty_cameras >= self.config.empty_cameras: if num_empty_cameras >= self.config.empty_cameras:
break continue
img = torch.ones_like(img) * -1 img = torch.ones_like(img) * -1
bsize = img.shape[0] bsize = img.shape[0]
device = img.device device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device) # FIXME(mshukor): similar to openpi, but should be zeros? mask = torch.ones(bsize, dtype=torch.bool, device=device)
# mask = torch.zeros(bsize, dtype=torch.bool, device=device)
# mask = torch.zeros_like(img)
# mask = torch.ones_like(mask)
num_empty_cameras += 1 num_empty_cameras += 1
images.append(img) images.append(img)
@ -620,31 +580,21 @@ class PI0FAST(nn.Module):
return fast_out return fast_out
def create_token_type_ids( def create_token_type_ids(
self, padded_mask: torch.Tensor, prefix_len: int, action_kw_len: int, state_len: torch.Tensor, mode: str = "prefix" self, padded_mask: torch.Tensor, prefix_len: int
) -> torch.Tensor: ) -> torch.Tensor:
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
# Compute cumulative sum mask # Compute cumulative sum mask
cumsum_mask = (padded_mask != 0).cumsum(dim=1) cumsum_mask = (padded_mask != 0).cumsum(dim=1)
# Suffix block (everything after prefix_len) # Suffix block (everything after prefix_len)
suffix_mask = cumsum_mask > prefix_len suffix_mask = cumsum_mask > prefix_len
if mode == "block_causal":
# Start of state (only one position)
start_state_mask = cumsum_mask == (prefix_len - (action_kw_len + state_len))
# Start of action (only one position)
start_action_mask = cumsum_mask >= (prefix_len - action_kw_len)
# Combine the masks
token_type_ids = suffix_mask | start_state_mask | start_action_mask
else:
token_type_ids = suffix_mask token_type_ids = suffix_mask
return token_type_ids return token_type_ids
def create_input_tokens(self, state, lang_text, actions=None, action_kw_to_prefix: bool = True): def create_input_tokens(self, state, lang_text, actions=None):
bsize = state.shape[0] bsize = state.shape[0]
device = state.device device = state.device
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
discretized = torch.bucketize(state, bins) - 1 discretized = torch.bucketize(state, bins) - 1
# TODO remove hardcoded parameter (32)
# discretized = F.pad(discretized, (0, max(0, 32 - discretized.shape[1])), value=0)[:, :32] # FIXME(mshukor): debug
discretized = discretized[:, :32] discretized = discretized[:, :32]
prefix_texts = [] prefix_texts = []
@ -652,9 +602,6 @@ class PI0FAST(nn.Module):
for txt, disc in zip(lang_text, discretized, strict=False): for txt, disc in zip(lang_text, discretized, strict=False):
cleaned = txt.lower().strip().replace("_", " ") cleaned = txt.lower().strip().replace("_", " ")
state_str = " ".join(str(val.item()) for val in disc) state_str = " ".join(str(val.item()) for val in disc)
if action_kw_to_prefix:
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\nAction:")
else:
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
state_text.append(f"State: {state_str};\n") state_text.append(f"State: {state_str};\n")
@ -665,11 +612,6 @@ class PI0FAST(nn.Module):
prefix_mask = prefix_out["attention_mask"].to(device) prefix_mask = prefix_out["attention_mask"].to(device)
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
state_lens = self.paligemma_tokenizer(
state_text, add_special_tokens=False, return_tensors="pt", padding="longest", truncation=False
).attention_mask.sum(1)[:, None]
action_kw_len = torch.tensor([2])[:, None] if action_kw_to_prefix else torch.tensor([0])[:, None] # corresponds to["Action:"]
if actions is not None: if actions is not None:
actions_norm = self.normalize_actions(actions) actions_norm = self.normalize_actions(actions)
actions_pad = F.pad( actions_pad = F.pad(
@ -682,7 +624,7 @@ class PI0FAST(nn.Module):
act_mask = fast_out["attention_mask"].to(device) act_mask = fast_out["attention_mask"].to(device)
act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device) act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
# replace action with 0 to pad tokens # Replace action with 0 to pad tokens
act_ids = torch.where( act_ids = torch.where(
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens, act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
self.pad_token_id, self.pad_token_id,
@ -693,21 +635,16 @@ class PI0FAST(nn.Module):
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
).expand(bsize, -1) ).expand(bsize, -1)
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
if action_kw_to_prefix:
act_ids = torch.cat([act_ids, eos_token], dim=1)
act_mask = torch.cat([act_mask, eos_mask], dim=1)
else:
bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt') bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device) bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device)
bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device) bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device)
#eos_mask = torch.ones_like(eos_token)
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
act_mask = act_mask.to(device) act_mask = act_mask.to(device)
else: else:
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device) act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
final_ids = torch.cat([prefix_ids, act_ids], dim=1) # act_ids already include postfix final_ids = torch.cat([prefix_ids, act_ids], dim=1)
final_mask = torch.cat([prefix_mask, act_mask], dim=1) final_mask = torch.cat([prefix_mask, act_mask], dim=1)
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()} batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
@ -721,10 +658,10 @@ class PI0FAST(nn.Module):
# define tensor of padding lengths # define tensor of padding lengths
att_mask = (padded_mask != 0).cumsum( att_mask = (padded_mask != 0).cumsum(
dim=1 dim=1
) > prefix_lens # [:, None].to(padded_mask.device) # need a batch indicator of prefix lengths OR NOT ) > prefix_lens
token_type_ids = self.create_token_type_ids( token_type_ids = self.create_token_type_ids(
padded_mask=padded_mask, prefix_len=prefix_lens, action_kw_len=action_kw_len, state_len=state_lens, mode=self.config.attention_mode padded_mask=padded_mask, prefix_len=prefix_lens
) )
padded_output["padded_mask"] = padded_output.pop("attention_mask") padded_output["padded_mask"] = padded_output.pop("attention_mask")
@ -776,10 +713,10 @@ class PI0FAST(nn.Module):
images, img_masks = self.prepare_images(batch) images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens( padded_outs = self.create_input_tokens(
state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION], action_kw_to_prefix=self.config.action_kw_to_prefix, state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION],
) )
embs, pad_masks, att_masks, targets, loss_mask, token_type_ids = self.embed_inputs( embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
images, images,
img_masks, img_masks,
padded_outs["input_ids"], padded_outs["input_ids"],
@ -842,6 +779,9 @@ class PI0FAST(nn.Module):
time_horizon: int | None = None, time_horizon: int | None = None,
action_dim: int | None = None, action_dim: int | None = None,
) -> np.array: ) -> np.array:
"""
Adapt original decoding in FAST to always return actions instead of zeros.
"""
self.time_horizon = ( self.time_horizon = (
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
) )
@ -905,10 +845,9 @@ class PI0FAST(nn.Module):
# Decode predicted output tokens # Decode predicted output tokens
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True) decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
cleaned_tokens = [ cleaned_tokens = [
tokens_sequence.replace(":", "").strip().split("|")[0].strip() tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
for tokens_sequence in decoded_tokens for tokens_sequence in decoded_tokens
] # should work ]
# TODO: for now let's use the processor tokenizer which encodes in the way we want (it is different from tusing the AutoTokenizer for some reasons)
raw_action_tokens = [ raw_action_tokens = [
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
for sample_tokens in cleaned_tokens for sample_tokens in cleaned_tokens
@ -936,7 +875,7 @@ class PI0FAST(nn.Module):
# TODO: keep like this or move to the policy .forward # TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch) images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None, action_kw_to_prefix=self.config.action_kw_to_prefix) padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
images, images,
img_masks, img_masks,
@ -954,18 +893,16 @@ class PI0FAST(nn.Module):
attention_mask=pad_masks, attention_mask=pad_masks,
position_ids=prefix_position_ids, position_ids=prefix_position_ids,
past_key_values=None, past_key_values=None,
inputs_embeds=embs, # No need for [prefix_embs, None] inputs_embeds=embs,
use_cache=self.config.use_cache, use_cache=self.config.use_cache,
max_new_tokens=self.config.max_decoding_steps, max_new_tokens=self.config.max_decoding_steps,
do_sample=False, do_sample=False,
num_beams=1, num_beams=1,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
) )
# import ipdb; ipdb.set_trace()
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim) actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
return actions return actions
# TODO: remove? seems uneeded
def embed_image(self, image: torch.Tensor): def embed_image(self, image: torch.Tensor):
return self.pi0_paligemma.get_image_features(image) return self.pi0_paligemma.get_image_features(image)
@ -1060,17 +997,3 @@ def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
# pad on left and top of image # pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector