relaxed decoding and some cleanings
This commit is contained in:
parent
aa38f06c29
commit
eea7cc424a
|
@ -73,13 +73,14 @@ class PI0FASTConfig(PreTrainedConfig):
|
|||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
checkpoint_path: str = None
|
||||
load_paligemma_weights: bool = False
|
||||
|
||||
padding_side: str = "right"
|
||||
|
||||
precision: str = "bfloat16"
|
||||
grad_clip_norm: float = 1
|
||||
|
||||
relaxed_decoding: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
|
|
@ -256,16 +256,17 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||
|
||||
|
||||
def block_causal_update_causal_mask(
|
||||
# self,
|
||||
attention_mask,
|
||||
token_type_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
input_tensor=None,
|
||||
# is_training: bool = None,
|
||||
attn_implementation: str = "eager",
|
||||
dtype: torch.dtype = "float32",
|
||||
):
|
||||
"""
|
||||
Update the causal mask during training and generation. It can be customized to different attention masks.
|
||||
"""
|
||||
if attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
|
@ -471,17 +472,7 @@ class PI0FAST(nn.Module):
|
|||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
if config.load_paligemma_weights:
|
||||
print("Loading google/paligemma-3b-pt-224 weights ...")
|
||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
"google/paligemma-3b-pt-224",
|
||||
device_map="cuda",
|
||||
torch_dtype=precision,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
else:
|
||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
||||
|
||||
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
||||
prepare_inputs_for_generation, self=self.pi0_paligemma
|
||||
|
@ -565,7 +556,7 @@ class PI0FAST(nn.Module):
|
|||
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
|
||||
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
||||
|
||||
def _act_tokens_to_paligemma_tokens1(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
return out
|
||||
|
||||
|
@ -621,7 +612,7 @@ class PI0FAST(nn.Module):
|
|||
act_ids = fast_out["input_ids"]
|
||||
act_mask = fast_out["attention_mask"].to(device)
|
||||
|
||||
act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
|
||||
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
|
||||
# Replace action with 0 to pad tokens
|
||||
act_ids = torch.where(
|
||||
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
||||
|
@ -774,6 +765,7 @@ class PI0FAST(nn.Module):
|
|||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
relaxed_decoding: bool = True,
|
||||
) -> np.array:
|
||||
"""
|
||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
||||
|
@ -798,18 +790,18 @@ class PI0FAST(nn.Module):
|
|||
try:
|
||||
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
||||
|
||||
# Expected sequence length
|
||||
expected_seq_len = self.time_horizon * self.action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
# Apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
|
||||
# Apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
if relaxed_decoding:
|
||||
# Expected sequence length
|
||||
expected_seq_len = self.time_horizon * self.action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
# Apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
|
||||
# Apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
|
@ -848,13 +840,13 @@ class PI0FAST(nn.Module):
|
|||
for sample_tokens in cleaned_tokens
|
||||
] # something like this should be robust #looks good
|
||||
action_tokens = [
|
||||
self._act_tokens_to_paligemma_tokens1(raw_action_token) for raw_action_token in raw_action_tokens
|
||||
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
|
||||
]
|
||||
# returns the tensor of decoded actions per sample in a list
|
||||
decoded_actions = [
|
||||
torch.tensor(
|
||||
self.decode_actions_with_fast(
|
||||
tok.tolist(), time_horizon=action_horizon, action_dim=action_dim
|
||||
tok.tolist(), time_horizon=action_horizon, action_dim=action_dim, relaxed_decoding=self.config.relaxed_decoding
|
||||
),
|
||||
device=tokens.device,
|
||||
).squeeze(0)
|
||||
|
|
Loading…
Reference in New Issue