clean and refactor pi0fast

mshukor 2025-03-31 16:28:39 +02:00
parent 7a45fa0fc1
commit 6e48e044d7
3 changed files with 37 additions and 131 deletions

View File

@@ -72,7 +72,7 @@
0.95
],
"optimizer_eps": 1e-08,
"optimizer_weight_decay": 1e-10,
"optimizer_weight_decay": 1e-5,
"scheduler_warmup_steps": 1000,
"scheduler_decay_steps": 30000,
"scheduler_decay_lr": 2.5e-06

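For context, AdamW applies decoupled weight decay: on every step each weight is additionally shrunk by roughly lr * weight_decay * weight. A small illustrative calculation of what the change from 1e-10 to 1e-5 amounts to (the learning rate below is an assumed placeholder, not read from this file):

lr = 1e-4  # assumed placeholder learning rate
for weight_decay in (1e-10, 1e-5):  # old vs. new value
    shrink_per_step = lr * weight_decay
    print(f"wd={weight_decay:g}: per-step relative shrinkage ~ {shrink_per_step:g}")
# wd=1e-10 gives ~1e-14 per step (decay effectively disabled); wd=1e-05 gives ~1e-09 per step.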
View File

@@ -8,13 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@dataclass
class PEFTConfig:
r: int = 4
lora_alpha: int = 16
lora_dropout: float = 0.1
target_modules: str = "q_proj,v_proj"
@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
@@ -71,10 +64,10 @@ class PI0FASTConfig(PreTrainedConfig):
freeze_lm_head: bool = True
# Training presets
optimizer_lr: float = 2.5e-5
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_weight_decay: float = 1e-5
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
@@ -85,15 +78,8 @@ class PI0FASTConfig(PreTrainedConfig):
padding_side: str = "right"
# peft_method: str = ""
# peft_config: PEFTConfig = PEFTConfig()
precision: str = "bfloat16"
attention_mode: str = "prefix"
action_kw_to_prefix: bool = True
# TODO: Add EMA
grad_clip_norm: float = 1
def __post_init__(self):
super().__post_init__()
@@ -110,10 +96,6 @@ class PI0FASTConfig(PreTrainedConfig):
)
def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
@@ -128,6 +110,7 @@ class PI0FASTConfig(PreTrainedConfig):
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self):

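For reference, a minimal sketch of roughly what these presets amount to in plain PyTorch (an assumed equivalent for illustration only; the exact optimizer and scheduler preset classes used by lerobot are not shown in this diff, and the cosine shape of the decay is an assumption):

import math
import torch

model = torch.nn.Linear(8, 8)  # placeholder module
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,            # optimizer_lr (raised from 2.5e-5)
    betas=(0.9, 0.95),  # optimizer_betas
    eps=1e-8,           # optimizer_eps
    weight_decay=1e-5,  # optimizer_weight_decay (raised from 1e-10)
)

def lr_lambda(step, warmup=1_000, decay_steps=30_000, base_lr=1e-4, final_lr=2.5e-6):
    # linear warmup to base_lr, then decay toward scheduler_decay_lr
    if step < warmup:
        return step / max(1, warmup)
    progress = min(1.0, (step - warmup) / max(1, decay_steps - warmup))
    return (final_lr + (base_lr - final_lr) * 0.5 * (1 + math.cos(math.pi * progress))) / base_lr

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# grad_clip_norm is now passed into the optimizer preset; its effect corresponds to per-step clipping like:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # grad_clip_norm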
View File

@@ -50,7 +50,6 @@ from functools import partial
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
# from peft import LoraConfig, TaskType, get_peft_model
from PIL import Image
from scipy.fft import idct
from torch import Tensor, nn
@@ -58,7 +57,6 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -175,7 +173,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
self.adapt_to_pi_aloha = True #self.config.adapt_to_pi_aloha # FIXME(mshukor): debug
self.reset()
@@ -223,7 +220,7 @@
"""
self.eval()
if self.adapt_to_pi_aloha:
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch)
@@ -242,7 +239,7 @@
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.adapt_to_pi_aloha:
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -251,7 +248,7 @@
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.adapt_to_pi_aloha:
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
@@ -275,8 +272,6 @@ def block_causal_update_causal_mask(
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# dtype = self.pi0_paligemma.dtype
# is_training = is_training if is_training is not None else self.training
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(dtype).min
@@ -398,7 +393,7 @@ def prepare_inputs_for_generation(
**kwargs,
)
# position_ids in Paligemma are 1-indexed
# Position_ids in Paligemma are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
@@ -491,19 +486,17 @@ class PI0FAST(nn.Module):
self.pi0_paligemma.prepare_inputs_for_generation = partial(
prepare_inputs_for_generation, self=self.pi0_paligemma
)
# self.pi0_paligemma = self.configure_peft(pi0_paligemma)
# change important stuff in bf16
params_to_change_dtype = [
"language_model",
"vision_tower",
"multi_modal",
]
print(f"Cast model params to {precision}")
for name, param in self.pi0_paligemma.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side
@@ -518,32 +511,6 @@ class PI0FAST(nn.Module):
if any([k in name for k in ["embed_tokens"]]): # lm heads and embedding layer are tied
params.requires_grad = False
# def configure_peft(self, model):
# self.peft_method = self.config.peft_method
# if "lora" in self.peft_method:
# peft_config = self.config.peft_config
# target_modules = peft_config.target_modules
# if not isinstance(target_modules, list):
# target_modules = target_modules.split(",")
# lora_config = LoraConfig(
# task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
# r=peft_config.r, # The rank of the low-rank adaptation
# lora_alpha=peft_config.lora_alpha, # Scaling factor
# lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
# target_modules=target_modules, # The components where LoRA is applied
# )
# self.lora_config = lora_config
# model = get_peft_model(model, lora_config)
# for name, param in model.named_parameters():
# if (
# "lora" in name
# ): # lm_head is not a parameter in most LLMs because it's tied to the embedding layer
# param.requires_grad = True
# else:
# param.requires_grad = False
# return model
def embed_tokens(self, tokens: torch.Tensor):
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
@@ -554,11 +521,7 @@ class PI0FAST(nn.Module):
"""Preprocess LeRobot batch into Pi0 inputs"""
images = []
img_masks = []
img_keys = sorted(self.config.image_features.keys(), key=lambda k: IMAGES_ORDER.get(k, float("inf")))
present_img_keys = [key for key in self.config.image_features if key in batch]
# missing_img_keys = [key for key in self.config.image_features if key not in batch]
# present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")))
present_img_keys = [key for key in self.image_keys if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
@@ -566,7 +529,7 @@ class PI0FAST(nn.Module):
# Preprocess image features present in the batch
num_empty_cameras = 0
for key in img_keys:
for key in self.image_keys:
if key in present_img_keys:
img = batch[key]
@@ -586,14 +549,11 @@ class PI0FAST(nn.Module):
mask = torch.ones(bsize, dtype=torch.bool, device=device)
else:
if num_empty_cameras >= self.config.empty_cameras:
break
continue
img = torch.ones_like(img) * -1
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device) # FIXME(mshukor): similar to openpi, but should be zeros?
# mask = torch.zeros(bsize, dtype=torch.bool, device=device)
# mask = torch.zeros_like(img)
# mask = torch.ones_like(mask)
mask = torch.ones(bsize, dtype=torch.bool, device=device)
num_empty_cameras += 1
images.append(img)
@@ -620,31 +580,21 @@ class PI0FAST(nn.Module):
return fast_out
def create_token_type_ids(
self, padded_mask: torch.Tensor, prefix_len: int, action_kw_len: int, state_len: torch.Tensor, mode: str = "prefix"
self, padded_mask: torch.Tensor, prefix_len: int
) -> torch.Tensor:
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
# Compute cumulative sum mask
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
# Suffix block (everything after prefix_len)
suffix_mask = cumsum_mask > prefix_len
if mode == "block_causal":
# Start of state (only one position)
start_state_mask = cumsum_mask == (prefix_len - (action_kw_len + state_len))
# Start of action (only one position)
start_action_mask = cumsum_mask >= (prefix_len - action_kw_len)
# Combine the masks
token_type_ids = suffix_mask | start_state_mask | start_action_mask
else:
token_type_ids = suffix_mask
token_type_ids = suffix_mask
return token_type_ids
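# Toy illustration (not part of this diff) of the simplified prefix-only token_type_ids:
#   padded_mask = [[1, 1, 1, 1, 1, 0, 0]]   # 5 real tokens, 2 padding positions
#   prefix_len  = 3
#   cumsum      = [[1, 2, 3, 4, 5, 5, 5]]   # (padded_mask != 0).cumsum(dim=1)
#   token_type  = [[0, 0, 0, 1, 1, 1, 1]]   # cumsum > prefix_len; padding is handled by the separate pad mask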
def create_input_tokens(self, state, lang_text, actions=None, action_kw_to_prefix: bool = True):
def create_input_tokens(self, state, lang_text, actions=None):
bsize = state.shape[0]
device = state.device
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
discretized = torch.bucketize(state, bins) - 1
# TODO remove hardcoded parameter (32)
# discretized = F.pad(discretized, (0, max(0, 32 - discretized.shape[1])), value=0)[:, :32] # FIXME(mshukor): debug
discretized = discretized[:, :32]
prefix_texts = []
@@ -652,10 +602,7 @@ class PI0FAST(nn.Module):
for txt, disc in zip(lang_text, discretized, strict=False):
cleaned = txt.lower().strip().replace("_", " ")
state_str = " ".join(str(val.item()) for val in disc)
if action_kw_to_prefix:
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\nAction:")
else:
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
state_text.append(f"State: {state_str};\n")
prefix_out = self.paligemma_tokenizer(
@@ -665,11 +612,6 @@ class PI0FAST(nn.Module):
prefix_mask = prefix_out["attention_mask"].to(device)
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
state_lens = self.paligemma_tokenizer(
state_text, add_special_tokens=False, return_tensors="pt", padding="longest", truncation=False
).attention_mask.sum(1)[:, None]
action_kw_len = torch.tensor([2])[:, None] if action_kw_to_prefix else torch.tensor([0])[:, None] # corresponds to ["Action:"]
if actions is not None:
actions_norm = self.normalize_actions(actions)
actions_pad = F.pad(
@@ -682,7 +624,7 @@ class PI0FAST(nn.Module):
act_mask = fast_out["attention_mask"].to(device)
act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
# replace action with 0 to pad tokens
# Replace action with 0 to pad tokens
act_ids = torch.where(
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
self.pad_token_id,
@@ -693,21 +635,16 @@ class PI0FAST(nn.Module):
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
).expand(bsize, -1)
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
if action_kw_to_prefix:
act_ids = torch.cat([act_ids, eos_token], dim=1)
act_mask = torch.cat([act_mask, eos_mask], dim=1)
else:
bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device)
bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device)
#eos_mask = torch.ones_like(eos_token)
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
act_mask = act_mask.to(device)
else:
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
final_ids = torch.cat([prefix_ids, act_ids], dim=1) # act_ids already include postfix
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
@@ -721,10 +658,10 @@ class PI0FAST(nn.Module):
# define tensor of padding lengths
att_mask = (padded_mask != 0).cumsum(
dim=1
) > prefix_lens # [:, None].to(padded_mask.device) # need a batch indicator of prefix lengths OR NOT
) > prefix_lens
token_type_ids = self.create_token_type_ids(
padded_mask=padded_mask, prefix_len=prefix_lens, action_kw_len=action_kw_len, state_len=state_lens, mode=self.config.attention_mode
padded_mask=padded_mask, prefix_len=prefix_lens
)
padded_output["padded_mask"] = padded_output.pop("attention_mask")
@@ -776,10 +713,10 @@ class PI0FAST(nn.Module):
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(
state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION], action_kw_to_prefix=self.config.action_kw_to_prefix,
state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION],
)
embs, pad_masks, att_masks, targets, loss_mask, token_type_ids = self.embed_inputs(
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
padded_outs["input_ids"],
@@ -842,6 +779,9 @@ class PI0FAST(nn.Module):
time_horizon: int | None = None,
action_dim: int | None = None,
) -> np.array:
"""
Adapt the original FAST decoding so that it always returns actions instead of zeros.
"""
self.time_horizon = (
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
)
@@ -905,10 +845,9 @@ class PI0FAST(nn.Module):
# Decode predicted output tokens
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
cleaned_tokens = [
tokens_sequence.replace(":", "").strip().split("|")[0].strip()
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
for tokens_sequence in decoded_tokens
] # should work
# TODO: for now let's use the processor tokenizer which encodes in the way we want (it is different from using the AutoTokenizer for some reason)
]
raw_action_tokens = [
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
for sample_tokens in cleaned_tokens
@@ -936,7 +875,7 @@ class PI0FAST(nn.Module):
# TODO: keep like this or move to the policy .forward
images, img_masks = self.prepare_images(batch)
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None, action_kw_to_prefix=self.config.action_kw_to_prefix)
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
images,
img_masks,
@@ -954,18 +893,16 @@ class PI0FAST(nn.Module):
attention_mask=pad_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=embs, # No need for [prefix_embs, None]
inputs_embeds=embs,
use_cache=self.config.use_cache,
max_new_tokens=self.config.max_decoding_steps,
do_sample=False,
num_beams=1,
token_type_ids=token_type_ids,
)
# import ipdb; ipdb.set_trace()
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
return actions
# TODO: remove? seems unneeded
def embed_image(self, image: torch.Tensor):
return self.pi0_paligemma.get_image_features(image)
@@ -1059,18 +996,4 @@ def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
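As a quick usage sketch of the retained resize_with_pad helper (shapes and values here are assumed for illustration, based on the visible signature and the left/top padding comment above):

import torch

img = torch.rand(1, 3, 200, 300)  # (B, C, H, W)
out = resize_with_pad(img, width=224, height=224)  # resize to fit, then pad on the left/top
print(out.shape)  # expected: torch.Size([1, 3, 224, 224])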