clean and refactor pi0fast

parent 7a45fa0fc1
commit 6e48e044d7

@@ -72,7 +72,7 @@
         0.95
     ],
     "optimizer_eps": 1e-08,
-    "optimizer_weight_decay": 1e-10,
+    "optimizer_weight_decay": 1e-5,
     "scheduler_warmup_steps": 1000,
     "scheduler_decay_steps": 30000,
     "scheduler_decay_lr": 2.5e-06
@@ -8,13 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


-@dataclass
-class PEFTConfig:
-    r: int = 4
-    lora_alpha: int = 16
-    lora_dropout: float = 0.1
-    target_modules: str = "q_proj,v_proj"
-

 @PreTrainedConfig.register_subclass("pi0fast")
 @dataclass
@@ -71,10 +64,10 @@ class PI0FASTConfig(PreTrainedConfig):
     freeze_lm_head: bool = True

     # Training presets
-    optimizer_lr: float = 2.5e-5
+    optimizer_lr: float = 1e-4
     optimizer_betas: tuple[float, float] = (0.9, 0.95)
     optimizer_eps: float = 1e-8
-    optimizer_weight_decay: float = 1e-10
+    optimizer_weight_decay: float = 1e-5

     scheduler_warmup_steps: int = 1_000
     scheduler_decay_steps: int = 30_000
@@ -85,15 +78,8 @@ class PI0FASTConfig(PreTrainedConfig):

     padding_side: str = "right"

-    # peft_method: str = ""
-    # peft_config: PEFTConfig = PEFTConfig()
-
     precision: str = "bfloat16"
-    attention_mode: str = "prefix"
-
-    action_kw_to_prefix: bool = True
-
-    # TODO: Add EMA
+    grad_clip_norm: float = 1

     def __post_init__(self):
         super().__post_init__()
@@ -110,10 +96,6 @@ class PI0FASTConfig(PreTrainedConfig):
         )

     def validate_features(self) -> None:
-        # TODO: implement value error
-        # if not self.image_features and not self.env_state_feature:
-        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")
-
         for i in range(self.empty_cameras):
             key = f"observation.images.empty_camera_{i}"
             empty_camera = PolicyFeature(
@@ -128,6 +110,7 @@ class PI0FASTConfig(PreTrainedConfig):
             betas=self.optimizer_betas,
             eps=self.optimizer_eps,
             weight_decay=self.optimizer_weight_decay,
+            grad_clip_norm=self.grad_clip_norm,
         )

     def get_scheduler_preset(self):
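
Taken together, the config-side hunks raise the default learning rate, increase weight decay, and route a new grad_clip_norm into the optimizer preset. As a rough standalone sketch (not part of the commit; the real PI0FASTConfig carries many more fields), the resulting training preset looks approximately like this:

from dataclasses import dataclass


@dataclass
class PI0FASTTrainingPreset:
    # Hypothetical mirror of the PI0FASTConfig training fields after this commit.
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-5

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    precision: str = "bfloat16"
    grad_clip_norm: float = 1
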
@@ -50,7 +50,6 @@ from functools import partial
 import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
-# from peft import LoraConfig, TaskType, get_peft_model
 from PIL import Image
 from scipy.fft import idct
 from torch import Tensor, nn
@@ -58,7 +57,6 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
 from transformers.cache_utils import HybridCache, StaticCache
 from transformers.models.auto import CONFIG_MAPPING

-from lerobot.common.constants import ACTION, OBS_ROBOT
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -175,7 +173,6 @@ class PI0FASTPolicy(PreTrainedPolicy):

         self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
         self.model = PI0FAST(config)
-        self.adapt_to_pi_aloha = True  # self.config.adapt_to_pi_aloha # FIXME(mshukor): debug

         self.reset()

@@ -223,7 +220,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
         """
         self.eval()

-        if self.adapt_to_pi_aloha:
+        if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])

         batch = self.normalize_inputs(batch)
@@ -242,7 +239,7 @@ class PI0FASTPolicy(PreTrainedPolicy):

         actions = self.unnormalize_outputs({"action": actions})["action"]

-        if self.adapt_to_pi_aloha:
+        if self.config.adapt_to_pi_aloha:
             actions = self._pi_aloha_encode_actions(actions)

         # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -251,7 +248,7 @@ class PI0FASTPolicy(PreTrainedPolicy):
         return self._action_queue.popleft()

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        if self.adapt_to_pi_aloha:
+        if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
             batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
         batch = self.normalize_inputs(batch)
@@ -275,8 +272,6 @@ def block_causal_update_causal_mask(
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
         return None
-    # dtype = self.pi0_paligemma.dtype
-    # is_training = is_training if is_training is not None else self.training
     using_static_cache = isinstance(past_key_values, StaticCache)
     min_dtype = torch.finfo(dtype).min

@@ -398,7 +393,7 @@ def prepare_inputs_for_generation(
         **kwargs,
     )

-    # position_ids in Paligemma are 1-indexed
+    # Position_ids in Paligemma are 1-indexed
     if model_inputs.get("position_ids") is not None:
         model_inputs["position_ids"] += 1
     # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
@@ -491,19 +486,17 @@ class PI0FAST(nn.Module):
         self.pi0_paligemma.prepare_inputs_for_generation = partial(
             prepare_inputs_for_generation, self=self.pi0_paligemma
         )
-        # self.pi0_paligemma = self.configure_peft(pi0_paligemma)
         # change important stuff in bf16
         params_to_change_dtype = [
             "language_model",
             "vision_tower",
             "multi_modal",
         ]
-        print(f"Cast model params to {precision}")
         for name, param in self.pi0_paligemma.named_parameters():
             if any(selector in name for selector in params_to_change_dtype):
                 param.data = param.data.to(dtype=torch_precision)
         self.set_requires_grad()
+        self.image_keys = self.config.image_features.keys()
         self.ignore_index = self.pi0_paligemma.config.ignore_index
         self.padding_side = self.config.padding_side

@@ -518,32 +511,6 @@ class PI0FAST(nn.Module):
             if any([k in name for k in ["embed_tokens"]]):  # lm heads and embedding layer are tied
                 params.requires_grad = False

-    # def configure_peft(self, model):
-    #     self.peft_method = self.config.peft_method
-    #     if "lora" in self.peft_method:
-    #         peft_config = self.config.peft_config
-    #         target_modules = peft_config.target_modules
-    #         if not isinstance(target_modules, list):
-    #             target_modules = target_modules.split(",")
-    #         lora_config = LoraConfig(
-    #             task_type=TaskType.CAUSAL_LM,  # Based on the task type (e.g., language modeling, etc.)
-    #             r=peft_config.r,  # The rank of the low-rank adaptation
-    #             lora_alpha=peft_config.lora_alpha,  # Scaling factor
-    #             lora_dropout=peft_config.lora_dropout,  # Dropout applied to LoRA layers
-    #             target_modules=target_modules,  # The components where LoRA is applied
-    #         )
-    #         self.lora_config = lora_config
-    #         model = get_peft_model(model, lora_config)
-    #         for name, param in model.named_parameters():
-    #             if (
-    #                 "lora" in name
-    #             ):  # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
-    #                 param.requires_grad = True
-    #             else:
-    #                 param.requires_grad = False
-
-    #     return model
-
     def embed_tokens(self, tokens: torch.Tensor):
         return self.pi0_paligemma.language_model.model.embed_tokens(tokens)

@@ -554,11 +521,7 @@ class PI0FAST(nn.Module):
         """Preprocess LeRobot batch into Pi0 inputs"""
         images = []
         img_masks = []
-        img_keys = sorted(self.config.image_features.keys(), key=lambda k: IMAGES_ORDER.get(k, float("inf")))
-        present_img_keys = [key for key in self.config.image_features if key in batch]
-        # missing_img_keys = [key for key in self.config.image_features if key not in batch]
-        # present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")))
-
+        present_img_keys = [key for key in self.image_keys if key in batch]
         if len(present_img_keys) == 0:
             raise ValueError(
                 f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
@@ -566,7 +529,7 @@ class PI0FAST(nn.Module):

         # Preprocess image features present in the batch
         num_empty_cameras = 0
-        for key in img_keys:
+        for key in self.image_keys:
             if key in present_img_keys:
                 img = batch[key]

@@ -586,14 +549,11 @@ class PI0FAST(nn.Module):
                 mask = torch.ones(bsize, dtype=torch.bool, device=device)
             else:
                 if num_empty_cameras >= self.config.empty_cameras:
-                    break
+                    continue
                 img = torch.ones_like(img) * -1
                 bsize = img.shape[0]
                 device = img.device
-                mask = torch.ones(bsize, dtype=torch.bool, device=device)  # FIXME(mshukor): similar to openpi, but should be zeros?
-                # mask = torch.zeros(bsize, dtype=torch.bool, device=device)
-                # mask = torch.zeros_like(img)
-                # mask = torch.ones_like(mask)
+                mask = torch.ones(bsize, dtype=torch.bool, device=device)
                 num_empty_cameras += 1

             images.append(img)
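
The `break` → `continue` switch matters here: with `break`, an exhausted empty-camera budget ended the loop and silently dropped every later camera key. A toy sketch of the new behavior (keys, shapes, and budget are made up for illustration):

import torch

image_keys = ["cam_high", "cam_low", "cam_wrist"]
batch = {"cam_high": torch.rand(2, 3, 224, 224), "cam_wrist": torch.rand(2, 3, 224, 224)}
empty_camera_budget = 0  # stands in for config.empty_cameras
num_empty_cameras = 0
images = []
for key in image_keys:
    if key in batch:
        images.append(batch[key])
    else:
        if num_empty_cameras >= empty_camera_budget:
            continue  # skip this slot but keep scanning later keys
        images.append(torch.ones(2, 3, 224, 224) * -1)  # placeholder frame
        num_empty_cameras += 1
assert len(images) == 2  # cam_wrist survives; `break` would have dropped it
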
@@ -620,31 +580,21 @@ class PI0FAST(nn.Module):
         return fast_out

     def create_token_type_ids(
-        self, padded_mask: torch.Tensor, prefix_len: int, action_kw_len: int, state_len: torch.Tensor, mode: str = "prefix"
+        self, padded_mask: torch.Tensor, prefix_len: int
     ) -> torch.Tensor:
         token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
         # Compute cumulative sum mask
         cumsum_mask = (padded_mask != 0).cumsum(dim=1)
         # Suffix block (everything after prefix_len)
         suffix_mask = cumsum_mask > prefix_len
-        if mode == "block_causal":
-            # Start of state (only one position)
-            start_state_mask = cumsum_mask == (prefix_len - (action_kw_len + state_len))
-            # Start of action (only one position)
-            start_action_mask = cumsum_mask >= (prefix_len - action_kw_len)
-            # Combine the masks
-            token_type_ids = suffix_mask | start_state_mask | start_action_mask
-        else:
-            token_type_ids = suffix_mask
+        token_type_ids = suffix_mask
         return token_type_ids

-    def create_input_tokens(self, state, lang_text, actions=None, action_kw_to_prefix: bool = True):
+    def create_input_tokens(self, state, lang_text, actions=None):
         bsize = state.shape[0]
         device = state.device
         bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
         discretized = torch.bucketize(state, bins) - 1
-        # TODO remove hardcoded parameter (32)
-        # discretized = F.pad(discretized, (0, max(0, 32 - discretized.shape[1])), value=0)[:, :32]  # FIXME(mshukor): debug
         discretized = discretized[:, :32]

         prefix_texts = []
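
With the block-causal branch gone, the token type ids reduce to a single suffix mask over cumulative non-pad positions. A toy run of the retained logic:

import torch

padded_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0],
                            [1, 1, 1, 0, 0, 0, 0]])
prefix_len = 3

cumsum_mask = (padded_mask != 0).cumsum(dim=1)  # running count of real tokens
suffix_mask = cumsum_mask > prefix_len          # True once past the prefix
print(suffix_mask.int())
# tensor([[0, 0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 0, 0, 0]])
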
@@ -652,9 +602,6 @@ class PI0FAST(nn.Module):
         for txt, disc in zip(lang_text, discretized, strict=False):
             cleaned = txt.lower().strip().replace("_", " ")
             state_str = " ".join(str(val.item()) for val in disc)
-            if action_kw_to_prefix:
-                prefix_texts.append(f"Task: {cleaned}, State: {state_str};\nAction:")
-            else:
-                prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
+            prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
             state_text.append(f"State: {state_str};\n")

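
End to end, the retained prompt path discretizes the state into 256 bins over [-1, 1] and writes it into a single prefix string. A small illustration with made-up values:

import torch

state = torch.tensor([[-0.999, 0.0, 0.999]])
bins = torch.linspace(-1, 1, 256 + 1)[:-1]
discretized = torch.bucketize(state, bins) - 1  # tensor([[  0, 127, 255]])

cleaned = "Pick_up_the_cube".lower().strip().replace("_", " ")
state_str = " ".join(str(v.item()) for v in discretized[0])
prefix = f"Task: {cleaned}, State: {state_str};\n"
# 'Task: pick up the cube, State: 0 127 255;\n'
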
@@ -665,11 +612,6 @@ class PI0FAST(nn.Module):
         prefix_mask = prefix_out["attention_mask"].to(device)
         prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()

-        state_lens = self.paligemma_tokenizer(
-            state_text, add_special_tokens=False, return_tensors="pt", padding="longest", truncation=False
-        ).attention_mask.sum(1)[:, None]
-        action_kw_len = torch.tensor([2])[:, None] if action_kw_to_prefix else torch.tensor([0])[:, None]  # corresponds to ["Action:"]
-
         if actions is not None:
             actions_norm = self.normalize_actions(actions)
             actions_pad = F.pad(
@@ -682,7 +624,7 @@ class PI0FAST(nn.Module):
             act_mask = fast_out["attention_mask"].to(device)

             act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
-            # replace action with 0 to pad tokens
+            # Replace action with 0 to pad tokens
             act_ids = torch.where(
                 act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
                 self.pad_token_id,
@@ -693,21 +635,16 @@ class PI0FAST(nn.Module):
                 [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
             ).expand(bsize, -1)
             eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
-            if action_kw_to_prefix:
-                act_ids = torch.cat([act_ids, eos_token], dim=1)
-                act_mask = torch.cat([act_mask, eos_mask], dim=1)
-            else:
-                bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
-                bos_token = bos['input_ids'].expand(act_ids.shape[0], -1).to(device)
-                bos_mask = bos['attention_mask'].expand(act_ids.shape[0], -1).to(device)
-                #eos_mask = torch.ones_like(eos_token)
-                act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
-                act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
+            bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
+            bos_token = bos['input_ids'].expand(act_ids.shape[0], -1).to(device)
+            bos_mask = bos['attention_mask'].expand(act_ids.shape[0], -1).to(device)
+            act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
+            act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
             act_mask = act_mask.to(device)
         else:
             act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
             act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
-        final_ids = torch.cat([prefix_ids, act_ids], dim=1)  # act_ids already include postfix
+        final_ids = torch.cat([prefix_ids, act_ids], dim=1)

         final_mask = torch.cat([prefix_mask, act_mask], dim=1)
         batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
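
After this hunk, the supervised sequence always carries the tokenized "Action: " keyword between the prefix and the FAST action tokens, terminated by EOS. A shape-only sketch with placeholder ids (not real vocabulary entries):

import torch

prefix_ids = torch.tensor([[5, 6, 7]])  # "Task: ..., State: ...;\n"
bos_ids = torch.tensor([[8, 9]])        # tokenized "Action: "
act_ids = torch.tensor([[900, 901]])    # FAST action tokens mapped into the PaliGemma vocab
eos_ids = torch.tensor([[2]])           # EOS

act_ids = torch.cat([bos_ids, act_ids, eos_ids], dim=1)
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
# final_ids: tensor([[  5,   6,   7,   8,   9, 900, 901,   2]])
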
@@ -721,10 +658,10 @@ class PI0FAST(nn.Module):
         # define tensor of padding lengths
         att_mask = (padded_mask != 0).cumsum(
             dim=1
-        ) > prefix_lens  # [:, None].to(padded_mask.device) # need a batch indicator of prefix lengths OR NOT
+        ) > prefix_lens

         token_type_ids = self.create_token_type_ids(
-            padded_mask=padded_mask, prefix_len=prefix_lens, action_kw_len=action_kw_len, state_len=state_lens, mode=self.config.attention_mode
+            padded_mask=padded_mask, prefix_len=prefix_lens
         )

         padded_output["padded_mask"] = padded_output.pop("attention_mask")
@@ -776,10 +713,10 @@ class PI0FAST(nn.Module):
         images, img_masks = self.prepare_images(batch)

         padded_outs = self.create_input_tokens(
-            state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION], action_kw_to_prefix=self.config.action_kw_to_prefix,
+            state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION],
         )

-        embs, pad_masks, att_masks, targets, loss_mask, token_type_ids = self.embed_inputs(
+        embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
             images,
             img_masks,
             padded_outs["input_ids"],
@@ -842,6 +779,9 @@ class PI0FAST(nn.Module):
         time_horizon: int | None = None,
         action_dim: int | None = None,
     ) -> np.array:
+        """
+        Adapt original decoding in FAST to always return actions instead of zeros.
+        """
         self.time_horizon = (
             time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
         )
@@ -905,10 +845,9 @@ class PI0FAST(nn.Module):
         # Decode predicted output tokens
         decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
         cleaned_tokens = [
-            tokens_sequence.replace(":", "").strip().split("|")[0].strip()
+            tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
             for tokens_sequence in decoded_tokens
-        ]  # should work
-        # TODO: for now let's use the processor tokenizer which encodes in the way we want (it is different from tusing the AutoTokenizer for some reasons)
+        ]
         raw_action_tokens = [
             self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
             for sample_tokens in cleaned_tokens
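
Since the generated text now starts with the "Action: " keyword, decoding strips it before re-encoding the remainder. A quick illustration of the new cleanup expression on a made-up decoded string:

decoded = "Action: 512 768 301 | padding"
cleaned = decoded.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
assert cleaned == "512 768 301"
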
@@ -936,7 +875,7 @@ class PI0FAST(nn.Module):
         # TODO: keep like this or move to the policy .forward
         images, img_masks = self.prepare_images(batch)

-        padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None, action_kw_to_prefix=self.config.action_kw_to_prefix)
+        padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
         embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
             images,
             img_masks,
@@ -954,18 +893,16 @@ class PI0FAST(nn.Module):
             attention_mask=pad_masks,
             position_ids=prefix_position_ids,
             past_key_values=None,
-            inputs_embeds=embs,  # No need for [prefix_embs, None]
+            inputs_embeds=embs,
             use_cache=self.config.use_cache,
             max_new_tokens=self.config.max_decoding_steps,
             do_sample=False,
             num_beams=1,
             token_type_ids=token_type_ids,
         )
-        # import ipdb; ipdb.set_trace()
         actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
         return actions

-    # TODO: remove? seems uneeded
     def embed_image(self, image: torch.Tensor):
         return self.pi0_paligemma.get_image_features(image)

@@ -1060,17 +997,3 @@ def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
     # pad on left and top of image
     padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
     return padded_img
-
-
-def pad_vector(vector, new_dim):
-    """Can be (batch_size x sequence_length x features_dimension)
-    or (batch_size x features_dimension)
-    """
-    if vector.shape[-1] == new_dim:
-        return vector
-    shape = list(vector.shape)
-    current_dim = shape[-1]
-    shape[-1] = new_dim
-    new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
-    new_vector[..., :current_dim] = vector
-    return new_vector