From eea7cc424abb0ed7050c659499b7f4b6ec337750 Mon Sep 17 00:00:00 2001
From: mshukor
Date: Wed, 2 Apr 2025 10:38:14 +0200
Subject: [PATCH] relaxed decoding and some cleanings

---
 .../policies/pi0fast/configuration_pi0fast.py |  3 +-
 .../policies/pi0fast/modeling_pi0fast.py      | 50 ++++++++-----------
 2 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/lerobot/common/policies/pi0fast/configuration_pi0fast.py b/lerobot/common/policies/pi0fast/configuration_pi0fast.py
index f56234ce..ecaaa2cb 100644
--- a/lerobot/common/policies/pi0fast/configuration_pi0fast.py
+++ b/lerobot/common/policies/pi0fast/configuration_pi0fast.py
@@ -73,13 +73,14 @@ class PI0FASTConfig(PreTrainedConfig):
     scheduler_decay_lr: float = 2.5e-6
 
     checkpoint_path: str = None
-    load_paligemma_weights: bool = False
 
     padding_side: str = "right"
 
     precision: str = "bfloat16"
     grad_clip_norm: float = 1
 
+    relaxed_decoding: bool = True
+
     def __post_init__(self):
         super().__post_init__()
 
diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
index f525299b..440645a3 100644
--- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py
+++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
@@ -256,16 +256,17 @@ class PI0FASTPolicy(PreTrainedPolicy):
 
 
 def block_causal_update_causal_mask(
-    # self,
     attention_mask,
     token_type_ids=None,
     past_key_values=None,
     cache_position=None,
     input_tensor=None,
-    # is_training: bool = None,
     attn_implementation: str = "eager",
     dtype: torch.dtype = "float32",
 ):
+    """
+    Update the causal mask during training and generation. It can be customized to different attention masks.
+    """
     if attn_implementation == "flash_attention_2":
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
@@ -471,17 +472,7 @@ class PI0FAST(nn.Module):
                 "vision_use_head": False,
             },
         )
-        if config.load_paligemma_weights:
-            print("Loading google/paligemma-3b-pt-224 weights ...")
-            self.pi0_paligemma = PaliGemmaForConditionalGeneration.from_pretrained(
-                "google/paligemma-3b-pt-224",
-                device_map="cuda",
-                torch_dtype=precision,
-                low_cpu_mem_usage=True,
-                attn_implementation="eager",
-            )
-        else:
-            self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
+        self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
 
         self.pi0_paligemma.prepare_inputs_for_generation = partial(
             prepare_inputs_for_generation, self=self.pi0_paligemma
@@ -565,7 +556,7 @@ class PI0FAST(nn.Module):
         maxs = actions.amax(dim=(1, 2), keepdim=True)  # [0]
         return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
 
-    def _act_tokens_to_paligemma_tokens1(self, tokens: torch.Tensor) -> torch.Tensor:
+    def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
         out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
         return out
 
@@ -621,7 +612,7 @@ class PI0FAST(nn.Module):
         act_ids = fast_out["input_ids"]
         act_mask = fast_out["attention_mask"].to(device)
 
-        act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
+        act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
         # Replace action with 0 to pad tokens
         act_ids = torch.where(
             act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
@@ -774,6 +765,7 @@ class PI0FAST(nn.Module):
         *,
         time_horizon: int | None = None,
         action_dim: int | None = None,
+        relaxed_decoding: bool = True,
     ) -> np.array:
         """
         Adapt original decoding in FAST to always return actions instead of zeros.
@@ -798,18 +790,18 @@ class PI0FAST(nn.Module):
             try:
                 decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
                 decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
-
-                # Expected sequence length
-                expected_seq_len = self.time_horizon * self.action_dim
-                diff = expected_seq_len - decoded_dct_coeff.shape[0]
-                # Apply truncation if too long
-                if diff < 0:
-                    decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # Truncate on the right
-                # Apply padding if too short
-                elif diff > 0:
-                    decoded_dct_coeff = np.pad(
-                        decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
-                    )
+                if relaxed_decoding:
+                    # Expected sequence length
+                    expected_seq_len = self.time_horizon * self.action_dim
+                    diff = expected_seq_len - decoded_dct_coeff.shape[0]
+                    # Apply truncation if too long
+                    if diff < 0:
+                        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # Truncate on the right
+                    # Apply padding if too short
+                    elif diff > 0:
+                        decoded_dct_coeff = np.pad(
+                            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
+                        )
 
                 decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
                 assert decoded_dct_coeff.shape == (
@@ -848,13 +840,13 @@ class PI0FAST(nn.Module):
             for sample_tokens in cleaned_tokens
         ]  # something like this should be robust #looks good
         action_tokens = [
-            self._act_tokens_to_paligemma_tokens1(raw_action_token) for raw_action_token in raw_action_tokens
+            self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
         ]
         # returns the tensor of decoded actions per sample in a list
         decoded_actions = [
             torch.tensor(
                 self.decode_actions_with_fast(
-                    tok.tolist(), time_horizon=action_horizon, action_dim=action_dim
+                    tok.tolist(), time_horizon=action_horizon, action_dim=action_dim, relaxed_decoding=self.config.relaxed_decoding
                 ),
                 device=tokens.device,
             ).squeeze(0)
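
Note: a minimal, self-contained sketch of the relaxed decoding step added above,
for readers of this patch. The `time_horizon`/`action_dim` values are
hypothetical, chosen only to make the snippet runnable; the pad/truncate logic
mirrors the hunk at @@ -798,18 +790,18 @@:

    import numpy as np

    time_horizon, action_dim = 10, 7              # hypothetical shapes
    expected_seq_len = time_horizon * action_dim  # 70

    # Suppose the FAST detokenizer returned 65 DCT coefficients instead of 70.
    decoded_dct_coeff = np.arange(65)

    diff = expected_seq_len - decoded_dct_coeff.shape[0]
    if diff < 0:
        # Too long: truncate on the right.
        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]
    elif diff > 0:
        # Too short: right-pad with zeros so the reshape below cannot fail.
        decoded_dct_coeff = np.pad(decoded_dct_coeff, (0, diff), mode="constant", constant_values=0)

    # The reshape now always succeeds, so decoding returns actions instead of
    # falling back to zeros on a shape mismatch (the pre-patch behavior).
    actions = decoded_dct_coeff.reshape(-1, action_dim)
    assert actions.shape == (time_horizon, action_dim)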