clean and refactor pi0fast
This commit is contained in:
parent
7a45fa0fc1
commit
6e48e044d7
|
@ -72,7 +72,7 @@
|
|||
0.95
|
||||
],
|
||||
"optimizer_eps": 1e-08,
|
||||
"optimizer_weight_decay": 1e-10,
|
||||
"optimizer_weight_decay": 1e-5,
|
||||
"scheduler_warmup_steps": 1000,
|
||||
"scheduler_decay_steps": 30000,
|
||||
"scheduler_decay_lr": 2.5e-06
|
||||
|
|
|
@ -8,13 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@dataclass
|
||||
class PEFTConfig:
|
||||
r: int = 4
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.1
|
||||
target_modules: str = "q_proj,v_proj"
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0fast")
|
||||
@dataclass
|
||||
|
@ -71,10 +64,10 @@ class PI0FASTConfig(PreTrainedConfig):
|
|||
freeze_lm_head: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
|
@ -85,15 +78,8 @@ class PI0FASTConfig(PreTrainedConfig):
|
|||
|
||||
padding_side: str = "right"
|
||||
|
||||
# peft_method: str = ""
|
||||
# peft_config: PEFTConfig = PEFTConfig()
|
||||
|
||||
precision: str = "bfloat16"
|
||||
attention_mode: str = "prefix"
|
||||
|
||||
action_kw_to_prefix: bool = True
|
||||
|
||||
# TODO: Add EMA
|
||||
grad_clip_norm: float = 1
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
@ -110,10 +96,6 @@ class PI0FASTConfig(PreTrainedConfig):
|
|||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
# if not self.image_features and not self.env_state_feature:
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
|
@ -128,6 +110,7 @@ class PI0FASTConfig(PreTrainedConfig):
|
|||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
|
|
@ -50,7 +50,6 @@ from functools import partial
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
# from peft import LoraConfig, TaskType, get_peft_model
|
||||
from PIL import Image
|
||||
from scipy.fft import idct
|
||||
from torch import Tensor, nn
|
||||
|
@ -58,7 +57,6 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
|
|||
from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
@ -175,7 +173,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
self.adapt_to_pi_aloha = True #self.config.adapt_to_pi_aloha # FIXME(mshukor): debug
|
||||
|
||||
self.reset()
|
||||
|
||||
|
@ -223,7 +220,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
"""
|
||||
self.eval()
|
||||
|
||||
if self.adapt_to_pi_aloha:
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
@ -242,7 +239,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.adapt_to_pi_aloha:
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
|
@ -251,7 +248,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
if self.adapt_to_pi_aloha:
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
@ -275,8 +272,6 @@ def block_causal_update_causal_mask(
|
|||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
# dtype = self.pi0_paligemma.dtype
|
||||
# is_training = is_training if is_training is not None else self.training
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
|
@ -398,7 +393,7 @@ def prepare_inputs_for_generation(
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
# position_ids in Paligemma are 1-indexed
|
||||
# Position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
|
@ -491,19 +486,17 @@ class PI0FAST(nn.Module):
|
|||
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
||||
prepare_inputs_for_generation, self=self.pi0_paligemma
|
||||
)
|
||||
# self.pi0_paligemma = self.configure_peft(pi0_paligemma)
|
||||
# change important stuff in bf16
|
||||
params_to_change_dtype = [
|
||||
"language_model",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
print(f"Cast model params to {precision}")
|
||||
for name, param in self.pi0_paligemma.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch_precision)
|
||||
self.set_requires_grad()
|
||||
|
||||
self.image_keys = self.config.image_features.keys()
|
||||
self.ignore_index = self.pi0_paligemma.config.ignore_index
|
||||
self.padding_side = self.config.padding_side
|
||||
|
||||
|
@ -518,32 +511,6 @@ class PI0FAST(nn.Module):
|
|||
if any([k in name for k in ["embed_tokens"]]): # lm heads and embedding layer are tied
|
||||
params.requires_grad = False
|
||||
|
||||
# def configure_peft(self, model):
|
||||
# self.peft_method = self.config.peft_method
|
||||
# if "lora" in self.peft_method:
|
||||
# peft_config = self.config.peft_config
|
||||
# target_modules = peft_config.target_modules
|
||||
# if not isinstance(target_modules, list):
|
||||
# target_modules = target_modules.split(",")
|
||||
# lora_config = LoraConfig(
|
||||
# task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
|
||||
# r=peft_config.r, # The rank of the low-rank adaptation
|
||||
# lora_alpha=peft_config.lora_alpha, # Scaling factor
|
||||
# lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
|
||||
# target_modules=target_modules, # The components where LoRA is applied
|
||||
# )
|
||||
# self.lora_config = lora_config
|
||||
# model = get_peft_model(model, lora_config)
|
||||
# for name, param in model.named_parameters():
|
||||
# if (
|
||||
# "lora" in name
|
||||
# ): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
|
||||
# param.requires_grad = True
|
||||
# else:
|
||||
# param.requires_grad = False
|
||||
|
||||
# return model
|
||||
|
||||
def embed_tokens(self, tokens: torch.Tensor):
|
||||
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
|
@ -554,11 +521,7 @@ class PI0FAST(nn.Module):
|
|||
"""Preprocess LeRobot batch into Pi0 inputs"""
|
||||
images = []
|
||||
img_masks = []
|
||||
img_keys = sorted(self.config.image_features.keys(), key=lambda k: IMAGES_ORDER.get(k, float("inf")))
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
# missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
# present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")))
|
||||
|
||||
present_img_keys = [key for key in self.image_keys if key in batch]
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
|
@ -566,7 +529,7 @@ class PI0FAST(nn.Module):
|
|||
|
||||
# Preprocess image features present in the batch
|
||||
num_empty_cameras = 0
|
||||
for key in img_keys:
|
||||
for key in self.image_keys:
|
||||
if key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
|
@ -586,14 +549,11 @@ class PI0FAST(nn.Module):
|
|||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
else:
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
continue
|
||||
img = torch.ones_like(img) * -1
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device) # FIXME(mshukor): similar to openpi, but should be zeros?
|
||||
# mask = torch.zeros(bsize, dtype=torch.bool, device=device)
|
||||
# mask = torch.zeros_like(img)
|
||||
# mask = torch.ones_like(mask)
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
num_empty_cameras += 1
|
||||
|
||||
images.append(img)
|
||||
|
@ -620,31 +580,21 @@ class PI0FAST(nn.Module):
|
|||
return fast_out
|
||||
|
||||
def create_token_type_ids(
|
||||
self, padded_mask: torch.Tensor, prefix_len: int, action_kw_len: int, state_len: torch.Tensor, mode: str = "prefix"
|
||||
self, padded_mask: torch.Tensor, prefix_len: int
|
||||
) -> torch.Tensor:
|
||||
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
|
||||
# Compute cumulative sum mask
|
||||
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
|
||||
# Suffix block (everything after prefix_len)
|
||||
suffix_mask = cumsum_mask > prefix_len
|
||||
if mode == "block_causal":
|
||||
# Start of state (only one position)
|
||||
start_state_mask = cumsum_mask == (prefix_len - (action_kw_len + state_len))
|
||||
# Start of action (only one position)
|
||||
start_action_mask = cumsum_mask >= (prefix_len - action_kw_len)
|
||||
# Combine the masks
|
||||
token_type_ids = suffix_mask | start_state_mask | start_action_mask
|
||||
else:
|
||||
token_type_ids = suffix_mask
|
||||
token_type_ids = suffix_mask
|
||||
return token_type_ids
|
||||
|
||||
def create_input_tokens(self, state, lang_text, actions=None, action_kw_to_prefix: bool = True):
|
||||
def create_input_tokens(self, state, lang_text, actions=None):
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
|
||||
discretized = torch.bucketize(state, bins) - 1
|
||||
# TODO remove hardcoded parameter (32)
|
||||
# discretized = F.pad(discretized, (0, max(0, 32 - discretized.shape[1])), value=0)[:, :32] # FIXME(mshukor): debug
|
||||
discretized = discretized[:, :32]
|
||||
|
||||
prefix_texts = []
|
||||
|
@ -652,10 +602,7 @@ class PI0FAST(nn.Module):
|
|||
for txt, disc in zip(lang_text, discretized, strict=False):
|
||||
cleaned = txt.lower().strip().replace("_", " ")
|
||||
state_str = " ".join(str(val.item()) for val in disc)
|
||||
if action_kw_to_prefix:
|
||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\nAction:")
|
||||
else:
|
||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
|
||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
|
||||
state_text.append(f"State: {state_str};\n")
|
||||
|
||||
prefix_out = self.paligemma_tokenizer(
|
||||
|
@ -665,11 +612,6 @@ class PI0FAST(nn.Module):
|
|||
prefix_mask = prefix_out["attention_mask"].to(device)
|
||||
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
|
||||
|
||||
state_lens = self.paligemma_tokenizer(
|
||||
state_text, add_special_tokens=False, return_tensors="pt", padding="longest", truncation=False
|
||||
).attention_mask.sum(1)[:, None]
|
||||
action_kw_len = torch.tensor([2])[:, None] if action_kw_to_prefix else torch.tensor([0])[:, None] # corresponds to["Action:"]
|
||||
|
||||
if actions is not None:
|
||||
actions_norm = self.normalize_actions(actions)
|
||||
actions_pad = F.pad(
|
||||
|
@ -682,7 +624,7 @@ class PI0FAST(nn.Module):
|
|||
act_mask = fast_out["attention_mask"].to(device)
|
||||
|
||||
act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
|
||||
# replace action with 0 to pad tokens
|
||||
# Replace action with 0 to pad tokens
|
||||
act_ids = torch.where(
|
||||
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
||||
self.pad_token_id,
|
||||
|
@ -693,21 +635,16 @@ class PI0FAST(nn.Module):
|
|||
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
|
||||
).expand(bsize, -1)
|
||||
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
|
||||
if action_kw_to_prefix:
|
||||
act_ids = torch.cat([act_ids, eos_token], dim=1)
|
||||
act_mask = torch.cat([act_mask, eos_mask], dim=1)
|
||||
else:
|
||||
bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
|
||||
bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device)
|
||||
bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device)
|
||||
#eos_mask = torch.ones_like(eos_token)
|
||||
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
|
||||
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
|
||||
bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
|
||||
bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device)
|
||||
bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device)
|
||||
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
|
||||
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
|
||||
act_mask = act_mask.to(device)
|
||||
else:
|
||||
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
|
||||
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
|
||||
final_ids = torch.cat([prefix_ids, act_ids], dim=1) # act_ids already include postfix
|
||||
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
|
||||
|
||||
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
|
||||
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
|
||||
|
@ -721,10 +658,10 @@ class PI0FAST(nn.Module):
|
|||
# define tensor of padding lengths
|
||||
att_mask = (padded_mask != 0).cumsum(
|
||||
dim=1
|
||||
) > prefix_lens # [:, None].to(padded_mask.device) # need a batch indicator of prefix lengths OR NOT
|
||||
) > prefix_lens
|
||||
|
||||
token_type_ids = self.create_token_type_ids(
|
||||
padded_mask=padded_mask, prefix_len=prefix_lens, action_kw_len=action_kw_len, state_len=state_lens, mode=self.config.attention_mode
|
||||
padded_mask=padded_mask, prefix_len=prefix_lens
|
||||
)
|
||||
|
||||
padded_output["padded_mask"] = padded_output.pop("attention_mask")
|
||||
|
@ -776,10 +713,10 @@ class PI0FAST(nn.Module):
|
|||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(
|
||||
state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION], action_kw_to_prefix=self.config.action_kw_to_prefix,
|
||||
state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION],
|
||||
)
|
||||
|
||||
embs, pad_masks, att_masks, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
padded_outs["input_ids"],
|
||||
|
@ -842,6 +779,9 @@ class PI0FAST(nn.Module):
|
|||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> np.array:
|
||||
"""
|
||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
||||
"""
|
||||
self.time_horizon = (
|
||||
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
|
||||
)
|
||||
|
@ -905,10 +845,9 @@ class PI0FAST(nn.Module):
|
|||
# Decode predicted output tokens
|
||||
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
|
||||
cleaned_tokens = [
|
||||
tokens_sequence.replace(":", "").strip().split("|")[0].strip()
|
||||
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
|
||||
for tokens_sequence in decoded_tokens
|
||||
] # should work
|
||||
# TODO: for now let's use the processor tokenizer which encodes in the way we want (it is different from tusing the AutoTokenizer for some reasons)
|
||||
]
|
||||
raw_action_tokens = [
|
||||
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
|
||||
for sample_tokens in cleaned_tokens
|
||||
|
@ -936,7 +875,7 @@ class PI0FAST(nn.Module):
|
|||
# TODO: keep like this or move to the policy .forward
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None, action_kw_to_prefix=self.config.action_kw_to_prefix)
|
||||
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
|
||||
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
|
@ -954,18 +893,16 @@ class PI0FAST(nn.Module):
|
|||
attention_mask=pad_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=embs, # No need for [prefix_embs, None]
|
||||
inputs_embeds=embs,
|
||||
use_cache=self.config.use_cache,
|
||||
max_new_tokens=self.config.max_decoding_steps,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
|
||||
return actions
|
||||
|
||||
# TODO: remove? seems uneeded
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.pi0_paligemma.get_image_features(image)
|
||||
|
||||
|
@ -1059,18 +996,4 @@ def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
|
|||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
return padded_img
|
Loading…
Reference in New Issue