relaxed decoding and some cleanups

mshukor 2025-04-02 10:38:14 +02:00
parent aa38f06c29
commit eea7cc424a
2 changed files with 23 additions and 30 deletions


@@ -73,13 +73,14 @@ class PI0FASTConfig(PreTrainedConfig):
     scheduler_decay_lr: float = 2.5e-6
     checkpoint_path: str | None = None
     load_paligemma_weights: bool = False
     padding_side: str = "right"
     precision: str = "bfloat16"
     grad_clip_norm: float = 1.0
+    relaxed_decoding: bool = True

     def __post_init__(self):
         super().__post_init__()
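
As a usage note, the new flag is toggled when the config is built. A minimal sketch, assuming the config class is importable from lerobot's pi0fast package (the import path is a guess, not verified against the repo layout):

from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig

# relaxed_decoding=True pads or truncates malformed FAST token sequences so
# decoding always yields an action chunk of shape (time_horizon, action_dim).
config = PI0FASTConfig(relaxed_decoding=True)

# Set it to False to keep the strict behavior and surface decoding mismatches.
strict_config = PI0FASTConfig(relaxed_decoding=False)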


@@ -256,16 +256,17 @@ class PI0FASTPolicy(PreTrainedPolicy):
 def block_causal_update_causal_mask(
     # NOTE: `self` and `is_training` from the original signature are intentionally dropped.
     attention_mask,
     token_type_ids=None,
     past_key_values=None,
     cache_position=None,
     input_tensor=None,
     attn_implementation: str = "eager",
     dtype: torch.dtype = torch.float32,
 ):
     """
     Update the causal mask during training and generation. It can be customized for different attention masks.
     """
     if attn_implementation == "flash_attention_2":
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
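
For intuition about what "block causal" means here: every token attends bidirectionally within its own block and causally to all earlier blocks. A standalone sketch of such a mask (an illustration of the concept, not the function above; block_ids is a hypothetical per-token block index such as one derived from token_type_ids):

import torch

def block_causal_mask(block_ids: torch.Tensor) -> torch.Tensor:
    # block_ids: (seq_len,) non-decreasing block index per token.
    # Position i may attend to position j iff block[j] <= block[i]:
    # bidirectional inside a block, causal across blocks.
    return block_ids.unsqueeze(1) >= block_ids.unsqueeze(0)

# Example: a 3-token prefix (block 0) followed by 2 action tokens (block 1).
mask = block_causal_mask(torch.tensor([0, 0, 0, 1, 1]))  # (5, 5) boolean mask
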
@@ -471,17 +472,7 @@ class PI0FAST(nn.Module):
         "vision_use_head": False,
     },
 )
-if config.load_paligemma_weights:
-    print("Loading google/paligemma-3b-pt-224 weights ...")
-    self.pi0_paligemma = PaliGemmaForConditionalGeneration.from_pretrained(
-        "google/paligemma-3b-pt-224",
-        device_map="cuda",
-        torch_dtype=precision,
-        low_cpu_mem_usage=True,
-        attn_implementation="eager",
-    )
-else:
-    self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
+self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)

 self.pi0_paligemma.prepare_inputs_for_generation = partial(
     prepare_inputs_for_generation, self=self.pi0_paligemma
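
The partial(...) assignment is an instance-level monkey-patch: a module-level function replaces the model's bound method and receives the model through the keyword argument self. A toy sketch of the same pattern (names are illustrative, not from this codebase):

from functools import partial

class Toy:
    pass

def custom_prepare(x, *, self):
    # Free function that still reaches the instance via the `self` keyword.
    return (type(self).__name__, x)

toy = Toy()
toy.prepare = partial(custom_prepare, self=toy)  # instance-level override
print(toy.prepare(3))  # ('Toy', 3)
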
@@ -565,7 +556,7 @@ class PI0FAST(nn.Module):
     maxs = actions.amax(dim=(1, 2), keepdim=True)
     return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1

-def _act_tokens_to_paligemma_tokens1(self, tokens: torch.Tensor) -> torch.Tensor:
+def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
     out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
     return out
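
The mapping counts backward from the top of the PaliGemma vocabulary, skipping fast_skip_tokens reserved slots, and it is its own inverse. A worked check with assumed values (257152 is PaliGemma's tokenizer vocabulary size; fast_skip_tokens=128 is a placeholder):

vocab_size = 257152        # assumed PaliGemma tokenizer vocab size
fast_skip_tokens = 128     # placeholder value

def act_to_paligemma(tok: int) -> int:
    return vocab_size - 1 - fast_skip_tokens - tok

assert act_to_paligemma(0) == 257023                  # FAST token 0 -> highest usable id
assert act_to_paligemma(act_to_paligemma(42)) == 42   # involution: the map inverts itself
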
@@ -621,7 +612,7 @@ class PI0FAST(nn.Module):
 act_ids = fast_out["input_ids"]
 act_mask = fast_out["attention_mask"].to(device)

-act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
+act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
 # Replace action tokens corresponding to FAST token 0 with pad tokens
 act_ids = torch.where(
     act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
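
The torch.where call replaces exactly the ids matching the sentinel (the image of FAST token 0 under the mapping above) and leaves everything else untouched. A toy illustration with assumed values:

import torch

sentinel = 257023  # assumed: vocab_size - 1 - fast_skip_tokens
ids = torch.tensor([12, sentinel, 99])
ids = torch.where(ids == sentinel, torch.zeros_like(ids), ids)
# -> tensor([12,  0, 99]); only the sentinel position becomes the pad id 0
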
@@ -774,6 +765,7 @@ class PI0FAST(nn.Module):
     *,
     time_horizon: int | None = None,
     action_dim: int | None = None,
+    relaxed_decoding: bool = True,
 ) -> np.ndarray:
     """
     Adapts the original FAST decoding to always return actions instead of zeros.
@@ -798,18 +790,18 @@ class PI0FAST(nn.Module):
 try:
     decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
     decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
-    # Expected sequence length
-    expected_seq_len = self.time_horizon * self.action_dim
-    diff = expected_seq_len - decoded_dct_coeff.shape[0]
-    # Apply truncation if too long
-    if diff < 0:
-        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # Truncate on the right
-    # Apply padding if too short
-    elif diff > 0:
-        decoded_dct_coeff = np.pad(
-            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
-        )
+    if relaxed_decoding:
+        # Expected sequence length
+        expected_seq_len = self.time_horizon * self.action_dim
+        diff = expected_seq_len - decoded_dct_coeff.shape[0]
+        # Apply truncation if too long
+        if diff < 0:
+            decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # Truncate on the right
+        # Apply padding if too short
+        elif diff > 0:
+            decoded_dct_coeff = np.pad(
+                decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
+            )
     decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
     assert decoded_dct_coeff.shape == (
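
Self-contained, the relaxed path just forces the coefficient vector to length time_horizon * action_dim before the reshape. A small sketch of that step in isolation:

import numpy as np

def fit_to_length(coeffs: np.ndarray, time_horizon: int, action_dim: int) -> np.ndarray:
    expected_seq_len = time_horizon * action_dim
    diff = expected_seq_len - coeffs.shape[0]
    if diff < 0:
        coeffs = coeffs[:expected_seq_len]  # too long: truncate on the right
    elif diff > 0:
        coeffs = np.pad(coeffs, (0, diff), mode="constant", constant_values=0)  # too short: zero-pad
    return coeffs.reshape(time_horizon, action_dim)

assert fit_to_length(np.arange(7), 2, 3).shape == (2, 3)  # truncated from 7 to 6
assert fit_to_length(np.arange(4), 2, 3).shape == (2, 3)  # padded from 4 to 6
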
@@ -848,13 +840,13 @@ class PI0FAST(nn.Module):
     for sample_tokens in cleaned_tokens
 ]
 action_tokens = [
-    self._act_tokens_to_paligemma_tokens1(raw_action_token) for raw_action_token in raw_action_tokens
+    self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
 ]
 # Returns the tensor of decoded actions per sample in a list
 decoded_actions = [
     torch.tensor(
         self.decode_actions_with_fast(
-            tok.tolist(), time_horizon=action_horizon, action_dim=action_dim
+            tok.tolist(),
+            time_horizon=action_horizon,
+            action_dim=action_dim,
+            relaxed_decoding=self.config.relaxed_decoding,
         ),
         device=tokens.device,
     ).squeeze(0)
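
Each entry of decoded_actions is a (time_horizon, action_dim) tensor on the source device; when every sample decodes to the same horizon they can be stacked into a batch downstream. A hypothetical follow-up, not code from this commit:

actions = torch.stack(decoded_actions, dim=0)  # (batch_size, time_horizon, action_dim)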