diff --git a/lerobot/common/policies/pi0fast/config.json b/lerobot/common/policies/pi0fast/config.json
index 4ea2098b..5ac603aa 100644
--- a/lerobot/common/policies/pi0fast/config.json
+++ b/lerobot/common/policies/pi0fast/config.json
@@ -72,7 +72,7 @@
         0.95
     ],
     "optimizer_eps": 1e-08,
-    "optimizer_weight_decay": 1e-10,
+    "optimizer_weight_decay": 1e-05,
     "scheduler_warmup_steps": 1000,
     "scheduler_decay_steps": 30000,
     "scheduler_decay_lr": 2.5e-06
diff --git a/lerobot/common/policies/pi0fast/configuration_pi0fast.py b/lerobot/common/policies/pi0fast/configuration_pi0fast.py
index 5cacc1a3..7c5d3db5 100644
--- a/lerobot/common/policies/pi0fast/configuration_pi0fast.py
+++ b/lerobot/common/policies/pi0fast/configuration_pi0fast.py
@@ -8,13 +8,6 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 
 
-@dataclass
-class PEFTConfig:
-    r: int = 4
-    lora_alpha: int = 16
-    lora_dropout: float = 0.1
-    target_modules: str = "q_proj,v_proj"
-
 @PreTrainedConfig.register_subclass("pi0fast")
 @dataclass
@@ -71,10 +64,10 @@ class PI0FASTConfig(PreTrainedConfig):
     freeze_lm_head: bool = True
 
     # Training presets
-    optimizer_lr: float = 2.5e-5
+    optimizer_lr: float = 1e-4
     optimizer_betas: tuple[float, float] = (0.9, 0.95)
     optimizer_eps: float = 1e-8
-    optimizer_weight_decay: float = 1e-10
+    optimizer_weight_decay: float = 1e-5
     scheduler_warmup_steps: int = 1_000
     scheduler_decay_steps: int = 30_000
@@ -85,15 +78,8 @@ class PI0FASTConfig(PreTrainedConfig):
     padding_side: str = "right"
 
-    # peft_method: str = ""
-    # peft_config: PEFTConfig = PEFTConfig()
-
     precision: str = "bfloat16"
-    attention_mode: str = "prefix"
-
-    action_kw_to_prefix: bool = True
-
-    # TODO: Add EMA
+    grad_clip_norm: float = 1.0
 
     def __post_init__(self):
         super().__post_init__()
@@ -110,10 +96,6 @@ class PI0FASTConfig(PreTrainedConfig):
         )
 
     def validate_features(self) -> None:
-        # TODO: implement value error
-        # if not self.image_features and not self.env_state_feature:
-        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")
-
         for i in range(self.empty_cameras):
             key = f"observation.images.empty_camera_{i}"
             empty_camera = PolicyFeature(
@@ -128,6 +110,7 @@
             betas=self.optimizer_betas,
             eps=self.optimizer_eps,
             weight_decay=self.optimizer_weight_decay,
+            grad_clip_norm=self.grad_clip_norm,
         )
 
     def get_scheduler_preset(self):
diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
index 7e75ed64..85297b3d 100644
--- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py
+++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
@@ -50,7 +50,6 @@ from functools import partial
 import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
-# from peft import LoraConfig, TaskType, get_peft_model
 from PIL import Image
 from scipy.fft import idct
 from torch import Tensor, nn
@@ -175,7 +174,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
         self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
         self.model = PI0FAST(config)
-        self.adapt_to_pi_aloha = True #self.config.adapt_to_pi_aloha # FIXME(mshukor): debug
 
         self.reset()
@@ -223,7 +221,7 @@
         """
         self.eval()
 
-        if self.adapt_to_pi_aloha:
+        if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
 
         batch = self.normalize_inputs(batch)
@@ -242,7 +240,7 @@
 
             actions = self.unnormalize_outputs({"action": actions})["action"]
 
-            if self.adapt_to_pi_aloha:
+            if self.config.adapt_to_pi_aloha:
                 actions = self._pi_aloha_encode_actions(actions)
 
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -251,7 +249,7 @@
         return self._action_queue.popleft()
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        if self.adapt_to_pi_aloha:
+        if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
             batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
         batch = self.normalize_inputs(batch)
@@ -275,8 +273,6 @@ def block_causal_update_causal_mask(
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
         return None
-    # dtype = self.pi0_paligemma.dtype
-    # is_training = is_training if is_training is not None else self.training
 
     using_static_cache = isinstance(past_key_values, StaticCache)
     min_dtype = torch.finfo(dtype).min
@@ -398,7 +394,7 @@ def prepare_inputs_for_generation(
         **kwargs,
     )
 
-    # position_ids in Paligemma are 1-indexed
+    # position_ids in PaliGemma are 1-indexed
     if model_inputs.get("position_ids") is not None:
         model_inputs["position_ids"] += 1
     # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
@@ -491,19 +487,17 @@ class PI0FAST(nn.Module):
         self.pi0_paligemma.prepare_inputs_for_generation = partial(
             prepare_inputs_for_generation, self=self.pi0_paligemma
         )
-        # self.pi0_paligemma = self.configure_peft(pi0_paligemma)
         # change important stuff in bf16
         params_to_change_dtype = [
             "language_model",
             "vision_tower",
             "multi_modal",
         ]
-        print(f"Cast model params to {precision}")
         for name, param in self.pi0_paligemma.named_parameters():
             if any(selector in name for selector in params_to_change_dtype):
                 param.data = param.data.to(dtype=torch_precision)
         self.set_requires_grad()
-
+        self.image_keys = self.config.image_features.keys()
         self.ignore_index = self.pi0_paligemma.config.ignore_index
         self.padding_side = self.config.padding_side
@@ -518,32 +512,6 @@ class PI0FAST(nn.Module):
             if any([k in name for k in ["embed_tokens"]]):  # lm heads and embedding layer are tied
                 params.requires_grad = False
 
-    # def configure_peft(self, model):
-    #     self.peft_method = self.config.peft_method
-    #     if "lora" in self.peft_method:
-    #         peft_config = self.config.peft_config
-    #         target_modules = peft_config.target_modules
-    #         if not isinstance(target_modules, list):
-    #             target_modules = target_modules.split(",")
-    #         lora_config = LoraConfig(
-    #             task_type=TaskType.CAUSAL_LM,  # Based on the task type (e.g., language modeling, etc.)
-    #             r=peft_config.r,  # The rank of the low-rank adaptation
-    #             lora_alpha=peft_config.lora_alpha,  # Scaling factor
-    #             lora_dropout=peft_config.lora_dropout,  # Dropout applied to LoRA layers
-    #             target_modules=target_modules,  # The components where LoRA is applied
-    #         )
-    #         self.lora_config = lora_config
-    #         model = get_peft_model(model, lora_config)
-    #         for name, param in model.named_parameters():
-    #             if (
-    #                 "lora" in name
-    #             ):  # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
-    #                 param.requires_grad = True
-    #             else:
-    #                 param.requires_grad = False
-
-    #     return model
-
     def embed_tokens(self, tokens: torch.Tensor):
         return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
@@ -554,11 +522,7 @@ class PI0FAST(nn.Module):
         """Preprocess LeRobot batch into Pi0 inputs"""
         images = []
         img_masks = []
-        img_keys = sorted(self.config.image_features.keys(), key=lambda k: IMAGES_ORDER.get(k, float("inf")))
-        present_img_keys = [key for key in self.config.image_features if key in batch]
-        # missing_img_keys = [key for key in self.config.image_features if key not in batch]
-        # present_img_keys = sorted(present_img_keys, key=lambda k: IMAGES_ORDER.get(k, float("inf")))
-
+        present_img_keys = [key for key in self.image_keys if key in batch]
         if len(present_img_keys) == 0:
             raise ValueError(
                 f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
@@ -566,7 +530,7 @@ class PI0FAST(nn.Module):
             )
 
         # Preprocess image features present in the batch
         num_empty_cameras = 0
-        for key in img_keys:
+        for key in self.image_keys:
             if key in present_img_keys:
                 img = batch[key]
@@ -586,14 +550,11 @@
                 mask = torch.ones(bsize, dtype=torch.bool, device=device)
             else:
                 if num_empty_cameras >= self.config.empty_cameras:
-                    break
+                    continue
                 img = torch.ones_like(img) * -1
                 bsize = img.shape[0]
                 device = img.device
-                mask = torch.ones(bsize, dtype=torch.bool, device=device) # FIXME(mshukor): similar to openpi, but should be zeros?
-                # mask = torch.zeros(bsize, dtype=torch.bool, device=device)
-                # mask = torch.zeros_like(img)
-                # mask = torch.ones_like(mask)
+                mask = torch.ones(bsize, dtype=torch.bool, device=device)
                 num_empty_cameras += 1
 
             images.append(img)
@@ -620,31 +581,21 @@ class PI0FAST(nn.Module):
         return fast_out
 
     def create_token_type_ids(
-        self, padded_mask: torch.Tensor, prefix_len: int, action_kw_len: int, state_len: torch.Tensor, mode: str = "prefix"
+        self, padded_mask: torch.Tensor, prefix_len: int
     ) -> torch.Tensor:
         token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
         # Compute cumulative sum mask
         cumsum_mask = (padded_mask != 0).cumsum(dim=1)
         # Suffix block (everything after prefix_len)
         suffix_mask = cumsum_mask > prefix_len
-        if mode == "block_causal":
-            # Start of state (only one position)
-            start_state_mask = cumsum_mask == (prefix_len - (action_kw_len + state_len))
-            # Start of action (only one position)
-            start_action_mask = cumsum_mask >= (prefix_len - action_kw_len)
-            # Combine the masks
-            token_type_ids = suffix_mask | start_state_mask | start_action_mask
-        else:
-            token_type_ids = suffix_mask
+        token_type_ids = suffix_mask
         return token_type_ids
 
-    def create_input_tokens(self, state, lang_text, actions=None, action_kw_to_prefix: bool = True):
+    def create_input_tokens(self, state, lang_text, actions=None):
         bsize = state.shape[0]
         device = state.device
         bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
         discretized = torch.bucketize(state, bins) - 1
-        # TODO remove hardcoded parameter (32)
-        # discretized = F.pad(discretized, (0, max(0, 32 - discretized.shape[1])), value=0)[:, :32] # FIXME(mshukor): debug
         discretized = discretized[:, :32]
         prefix_texts = []
         state_text = []
@@ -652,10 +603,7 @@
         for txt, disc in zip(lang_text, discretized, strict=False):
             cleaned = txt.lower().strip().replace("_", " ")
             state_str = " ".join(str(val.item()) for val in disc)
-            if action_kw_to_prefix:
-                prefix_texts.append(f"Task: {cleaned}, State: {state_str};\nAction:")
-            else:
-                prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
+            prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
             state_text.append(f"State: {state_str};\n")
 
         prefix_out = self.paligemma_tokenizer(
@@ -665,11 +613,6 @@
         prefix_mask = prefix_out["attention_mask"].to(device)
         prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
 
-        state_lens = self.paligemma_tokenizer(
-            state_text, add_special_tokens=False, return_tensors="pt", padding="longest", truncation=False
-        ).attention_mask.sum(1)[:, None]
-        action_kw_len = torch.tensor([2])[:, None] if action_kw_to_prefix else torch.tensor([0])[:, None] # corresponds to["Action:"]
-
         if actions is not None:
             actions_norm = self.normalize_actions(actions)
             actions_pad = F.pad(
@@ -682,7 +625,7 @@
             act_mask = fast_out["attention_mask"].to(device)
 
             act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
-            # replace action with 0 to pad tokens
+            # Replace action token 0 (FAST padding) with the PaliGemma pad token
             act_ids = torch.where(
                 act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
                 self.pad_token_id,
@@ -693,21 +636,16 @@
                 [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
             ).expand(bsize, -1)
             eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
-            if action_kw_to_prefix:
-                act_ids = torch.cat([act_ids, eos_token], dim=1)
-                act_mask = torch.cat([act_mask, eos_mask], dim=1)
-            else:
-                bos = self.paligemma_tokenizer('Action: ', add_special_tokens=False, return_tensors='pt')
-                bos_token = bos['input_ids'].expand(act_ids.shape[0],-1).to(device)
-                bos_mask = bos['attention_mask'].expand(act_ids.shape[0],-1).to(device)
-                #eos_mask = torch.ones_like(eos_token)
-                act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
-                act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
+            bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
+            bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
+            bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
+            act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
+            act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
             act_mask = act_mask.to(device)
         else:
             act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
             act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
-        final_ids = torch.cat([prefix_ids, act_ids], dim=1) # act_ids already include postfix
+        final_ids = torch.cat([prefix_ids, act_ids], dim=1)
         final_mask = torch.cat([prefix_mask, act_mask], dim=1)
 
         batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
@@ -721,10 +659,10 @@
         # define tensor of padding lengths
         att_mask = (padded_mask != 0).cumsum(
             dim=1
-        ) > prefix_lens # [:, None].to(padded_mask.device) # need a batch indicator of prefix lengths OR NOT
+        ) > prefix_lens
 
         token_type_ids = self.create_token_type_ids(
-            padded_mask=padded_mask, prefix_len=prefix_lens, action_kw_len=action_kw_len, state_len=state_lens, mode=self.config.attention_mode
+            padded_mask=padded_mask, prefix_len=prefix_lens
         )
 
         padded_output["padded_mask"] = padded_output.pop("attention_mask")
@@ -776,10 +714,10 @@
         images, img_masks = self.prepare_images(batch)
 
         padded_outs = self.create_input_tokens(
-            state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION], action_kw_to_prefix=self.config.action_kw_to_prefix,
+            state=batch[OBS_ROBOT], lang_text=batch["task"], actions=batch[ACTION],
         )
 
-        embs, pad_masks, att_masks, targets, loss_mask, token_type_ids = self.embed_inputs(
+        embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
             images,
             img_masks,
             padded_outs["input_ids"],
@@ -842,6 +780,9 @@
         time_horizon: int | None = None,
         action_dim: int | None = None,
     ) -> np.array:
+        """
+        Adapt original decoding in FAST to always return actions instead of zeros.
+ """ self.time_horizon = ( time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon ) @@ -905,10 +845,9 @@ class PI0FAST(nn.Module): # Decode predicted output tokens decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True) cleaned_tokens = [ - tokens_sequence.replace(":", "").strip().split("|")[0].strip() + tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip() for tokens_sequence in decoded_tokens - ] # should work - # TODO: for now let's use the processor tokenizer which encodes in the way we want (it is different from tusing the AutoTokenizer for some reasons) + ] raw_action_tokens = [ self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) for sample_tokens in cleaned_tokens @@ -936,7 +875,7 @@ class PI0FAST(nn.Module): # TODO: keep like this or move to the policy .forward images, img_masks = self.prepare_images(batch) - padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None, action_kw_to_prefix=self.config.action_kw_to_prefix) + padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None) embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( images, img_masks, @@ -954,18 +893,16 @@ class PI0FAST(nn.Module): attention_mask=pad_masks, position_ids=prefix_position_ids, past_key_values=None, - inputs_embeds=embs, # No need for [prefix_embs, None] + inputs_embeds=embs, use_cache=self.config.use_cache, max_new_tokens=self.config.max_decoding_steps, do_sample=False, num_beams=1, token_type_ids=token_type_ids, ) - # import ipdb; ipdb.set_trace() actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim) return actions - # TODO: remove? seems uneeded def embed_image(self, image: torch.Tensor): return self.pi0_paligemma.get_image_features(image) @@ -1059,18 +996,4 @@ def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True): # pad on left and top of image padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img - - -def pad_vector(vector, new_dim): - """Can be (batch_size x sequence_length x features_dimension) - or (batch_size x features_dimension) - """ - if vector.shape[-1] == new_dim: - return vector - shape = list(vector.shape) - current_dim = shape[-1] - shape[-1] = new_dim - new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) - new_vector[..., :current_dim] = vector - return new_vector + return padded_img \ No newline at end of file