Relaxed decoding and some cleanup
This commit is contained in:
parent
aa38f06c29
commit
eea7cc424a
|
@ -73,13 +73,14 @@ class PI0FASTConfig(PreTrainedConfig):
|
||||||
scheduler_decay_lr: float = 2.5e-6
|
scheduler_decay_lr: float = 2.5e-6
|
||||||
|
|
||||||
checkpoint_path: str = None
|
checkpoint_path: str = None
|
||||||
load_paligemma_weights: bool = False
|
|
||||||
|
|
||||||
padding_side: str = "right"
|
padding_side: str = "right"
|
||||||
|
|
||||||
precision: str = "bfloat16"
|
precision: str = "bfloat16"
|
||||||
grad_clip_norm: float = 1
|
grad_clip_norm: float = 1
|
||||||
|
|
||||||
|
relaxed_decoding: bool = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
|
|
|
@ -256,16 +256,17 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||||
|
|
||||||
|
|
||||||
def block_causal_update_causal_mask(
|
def block_causal_update_causal_mask(
|
||||||
# self,
|
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
cache_position=None,
|
cache_position=None,
|
||||||
input_tensor=None,
|
input_tensor=None,
|
||||||
# is_training: bool = None,
|
|
||||||
attn_implementation: str = "eager",
|
attn_implementation: str = "eager",
|
||||||
dtype: torch.dtype = "float32",
|
dtype: torch.dtype = "float32",
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Update the causal mask during training and generation. It can be customized to different attention masks.
|
||||||
|
"""
|
||||||
if attn_implementation == "flash_attention_2":
|
if attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
@ -471,16 +472,6 @@ class PI0FAST(nn.Module):
|
||||||
"vision_use_head": False,
|
"vision_use_head": False,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if config.load_paligemma_weights:
|
|
||||||
print("Loading google/paligemma-3b-pt-224 weights ...")
|
|
||||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration.from_pretrained(
|
|
||||||
"google/paligemma-3b-pt-224",
|
|
||||||
device_map="cuda",
|
|
||||||
torch_dtype=precision,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
attn_implementation="eager",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
||||||
|
|
||||||
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
||||||
|
@ -565,7 +556,7 @@ class PI0FAST(nn.Module):
|
||||||
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
|
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
|
||||||
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
||||||
|
|
||||||
def _act_tokens_to_paligemma_tokens1(self, tokens: torch.Tensor) -> torch.Tensor:
|
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||||
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -621,7 +612,7 @@ class PI0FAST(nn.Module):
|
||||||
act_ids = fast_out["input_ids"]
|
act_ids = fast_out["input_ids"]
|
||||||
act_mask = fast_out["attention_mask"].to(device)
|
act_mask = fast_out["attention_mask"].to(device)
|
||||||
|
|
||||||
act_ids = self._act_tokens_to_paligemma_tokens1(act_ids).to(device)
|
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
|
||||||
# Replace action with 0 to pad tokens
|
# Replace action with 0 to pad tokens
|
||||||
act_ids = torch.where(
|
act_ids = torch.where(
|
||||||
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
||||||
|
@ -774,6 +765,7 @@ class PI0FAST(nn.Module):
|
||||||
*,
|
*,
|
||||||
time_horizon: int | None = None,
|
time_horizon: int | None = None,
|
||||||
action_dim: int | None = None,
|
action_dim: int | None = None,
|
||||||
|
relaxed_decoding: bool = True,
|
||||||
) -> np.array:
|
) -> np.array:
|
||||||
"""
|
"""
|
||||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
Adapt original decoding in FAST to always return actions instead of zeros.
|
||||||
|
@ -798,7 +790,7 @@ class PI0FAST(nn.Module):
|
||||||
try:
|
try:
|
||||||
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
||||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
||||||
|
if relaxed_decoding:
|
||||||
# Expected sequence length
|
# Expected sequence length
|
||||||
expected_seq_len = self.time_horizon * self.action_dim
|
expected_seq_len = self.time_horizon * self.action_dim
|
||||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||||
|
@ -848,13 +840,13 @@ class PI0FAST(nn.Module):
|
||||||
for sample_tokens in cleaned_tokens
|
for sample_tokens in cleaned_tokens
|
||||||
] # something like this should be robust #looks good
|
] # something like this should be robust #looks good
|
||||||
action_tokens = [
|
action_tokens = [
|
||||||
self._act_tokens_to_paligemma_tokens1(raw_action_token) for raw_action_token in raw_action_tokens
|
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
|
||||||
]
|
]
|
||||||
# returns the tensor of decoded actions per sample in a list
|
# returns the tensor of decoded actions per sample in a list
|
||||||
decoded_actions = [
|
decoded_actions = [
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
self.decode_actions_with_fast(
|
self.decode_actions_with_fast(
|
||||||
tok.tolist(), time_horizon=action_horizon, action_dim=action_dim
|
tok.tolist(), time_horizon=action_horizon, action_dim=action_dim, relaxed_decoding=self.config.relaxed_decoding
|
||||||
),
|
),
|
||||||
device=tokens.device,
|
device=tokens.device,
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
Loading…
Reference in New Issue