relaxed decoding and some cleanup

mshukor 2025-04-02 10:38:14 +02:00
parent aa38f06c29
commit eea7cc424a
2 changed files with 23 additions and 30 deletions


@@ -73,13 +73,14 @@ class PI0FASTConfig(PreTrainedConfig):
     scheduler_decay_lr: float = 2.5e-6
     checkpoint_path: str = None
-    load_paligemma_weights: bool = False
     padding_side: str = "right"
     precision: str = "bfloat16"
     grad_clip_norm: float = 1
+    relaxed_decoding: bool = True

     def __post_init__(self):
         super().__post_init__()

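The config change above drops the ad-hoc load_paligemma_weights switch and adds a single relaxed_decoding flag. A minimal usage sketch; the import path is an assumption, not something shown in this diff:

    # Sketch only; the import path below is assumed, not taken from this commit.
    from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig

    # relaxed_decoding defaults to True after this commit; set it to False to
    # keep strict FAST decoding, which rejects token sequences whose decoded
    # length does not match time_horizon * action_dim.
    config = PI0FASTConfig(relaxed_decoding=False)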

@@ -256,16 +256,17 @@ class PI0FASTPolicy(PreTrainedPolicy):
 def block_causal_update_causal_mask(
-    # self,
     attention_mask,
     token_type_ids=None,
     past_key_values=None,
     cache_position=None,
     input_tensor=None,
-    # is_training: bool = None,
     attn_implementation: str = "eager",
     dtype: torch.dtype = "float32",
 ):
+    """
+    Update the causal mask during training and generation. It can be customized for different attention masks.
+    """
     if attn_implementation == "flash_attention_2":
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
@@ -471,16 +472,6 @@ class PI0FAST(nn.Module):
                 "vision_use_head": False,
             },
         )
-        if config.load_paligemma_weights:
-            print("Loading google/paligemma-3b-pt-224 weights ...")
-            self.pi0_paligemma = PaliGemmaForConditionalGeneration.from_pretrained(
-                "google/paligemma-3b-pt-224",
-                device_map="cuda",
-                torch_dtype=precision,
-                low_cpu_mem_usage=True,
-                attn_implementation="eager",
-            )
-        else:
-            self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
+        self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)

         self.pi0_paligemma.prepare_inputs_for_generation = partial(
@@ -565,7 +556,7 @@ class PI0FAST(nn.Module):
         maxs = actions.amax(dim=(1, 2), keepdim=True)  # [0]
         return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1

-    def _act_tokens_to_paligemma_tokens1(self, tokens: torch.Tensor) -> torch.Tensor:
+    def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
         out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
         return out
@@ -621,7 +612,7 @@ class PI0FAST(nn.Module):
         act_ids = fast_out["input_ids"]
         act_mask = fast_out["attention_mask"].to(device)

-        act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
+        act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
         # Replace action with 0 to pad tokens
         act_ids = torch.where(
             act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
@@ -774,6 +765,7 @@ class PI0FAST(nn.Module):
         *,
         time_horizon: int | None = None,
         action_dim: int | None = None,
+        relaxed_decoding: bool = True,
     ) -> np.array:
         """
         Adapt original decoding in FAST to always return actions instead of zeros.
@@ -798,7 +790,7 @@ class PI0FAST(nn.Module):
             try:
                 decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
                 decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
-                # Expected sequence length
-                expected_seq_len = self.time_horizon * self.action_dim
-                diff = expected_seq_len - decoded_dct_coeff.shape[0]
+                if relaxed_decoding:
+                    # Expected sequence length
+                    expected_seq_len = self.time_horizon * self.action_dim
+                    diff = expected_seq_len - decoded_dct_coeff.shape[0]
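The hunk above ends right after diff is computed, so the fix-up itself is not visible here. Below is a sketch of the usual relaxed-decoding repair (right-pad short coefficient sequences, truncate long ones); it is an assumption consistent with the flag's purpose, not a quote of the committed code:

    import numpy as np

    TIME_HORIZON, ACTION_DIM = 8, 7          # illustrative values
    expected_seq_len = TIME_HORIZON * ACTION_DIM
    decoded_dct_coeff = np.zeros(50)         # stand-in for the decoded coefficients
    diff = expected_seq_len - decoded_dct_coeff.shape[0]

    if diff > 0:
        # Too short: right-pad with zeros so the downstream reshape cannot fail.
        decoded_dct_coeff = np.pad(decoded_dct_coeff, (0, diff), constant_values=0)
    elif diff < 0:
        # Too long: drop the trailing coefficients.
        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]

    coeffs = decoded_dct_coeff.reshape(TIME_HORIZON, ACTION_DIM)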
@@ -848,13 +840,13 @@ class PI0FAST(nn.Module):
             for sample_tokens in cleaned_tokens
         ]  # something like this should be robust #looks good
         action_tokens = [
-            self._act_tokens_to_paligemma_tokens1(raw_action_token) for raw_action_token in raw_action_tokens
+            self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
         ]
         # returns the tensor of decoded actions per sample in a list
         decoded_actions = [
             torch.tensor(
                 self.decode_actions_with_fast(
-                    tok.tolist(), time_horizon=action_horizon, action_dim=action_dim
+                    tok.tolist(), time_horizon=action_horizon, action_dim=action_dim, relaxed_decoding=self.config.relaxed_decoding
                 ),
                 device=tokens.device,
             ).squeeze(0)
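Note the design choice in the final hunk: decode_actions_with_fast takes relaxed_decoding as a keyword argument defaulting to True rather than reading self.config internally, so direct callers keep working unchanged while the policy threads the configured value through. A sketch of such a direct call, with illustrative argument values:

    # Hypothetical direct call; inside the policy, relaxed_decoding comes from config.
    actions = model.decode_actions_with_fast(
        tok.tolist(), time_horizon=8, action_dim=7, relaxed_decoding=False
    )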