cyf_update

wk 2025-03-11 13:18:08 +08:00
parent f7d664dcc0
commit 9f4d490423
5 changed files with 46 additions and 46 deletions

View File

@@ -124,7 +124,7 @@ class DexVLAPolicy(PreTrainedPolicy):
is_eval=True,
pixel_values=None,
attention_mask=None,
- image_grid_thw=None,
+ image_grid_spatiotemporal=None,
):
input_ids = input_ids.to("cuda")
with torch.inference_mode():
@@ -132,7 +132,7 @@ class DexVLAPolicy(PreTrainedPolicy):
input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
- image_grid_thw=image_grid_thw,
+ image_grid_spatiotemporal=image_grid_spatiotemporal,
is_eval=is_eval,
num_beams=1,
do_sample=False,
@@ -180,7 +180,7 @@ class DexVLAPolicy(PreTrainedPolicy):
is_eval=True,
pixel_values=None,
attention_mask=None,
- image_grid_thw=None,
+ image_grid_spatiotemporal=None,
):
input_ids = input_ids.to("cuda")
with torch.inference_mode():
@@ -188,7 +188,7 @@ class DexVLAPolicy(PreTrainedPolicy):
input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
- image_grid_thw=image_grid_thw,
+ image_grid_spatiotemporal=image_grid_spatiotemporal,
is_eval=is_eval,
tinyvla=True,
)

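Note: after this rename the model expects the grid tensor under the new keyword, while the upstream Qwen2-VL processor still emits it as image_grid_thw, so callers have to map the key. A minimal sketch of driving the renamed signature; the checkpoint name, stand-in image, and the policy.evaluate entry point are placeholders (the hunks above do not show the method name), and chat-template/image-token formatting is elided:

import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # placeholder checkpoint
image = Image.new("RGB", (336, 336))                                    # stand-in camera frame
inputs = processor(text=["pick up the cube"], images=[image], return_tensors="pt")

with torch.inference_mode():
    outputs = policy.evaluate(                    # hypothetical entry point; `policy` assumed built
        inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        attention_mask=inputs["attention_mask"],
        image_grid_spatiotemporal=inputs["image_grid_thw"],  # processor key keeps the old name
    )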
View File

@@ -433,7 +433,7 @@ class ScaleDP(PreTrainedModel):
Tp = self.num_queries
action_dim = self.action_dim
- # initialize action from Guassian noise
+ # initialize action from Gaussian noise
noisy_action = torch.randn((B, Tp, action_dim)).cuda()
naction = noisy_action.to(dtype=hidden_states.dtype)

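Note: the corrected comment describes the standard diffusion-head initialization: the action chunk starts as pure Gaussian noise of shape (batch, prediction horizon, action dim) and is then iteratively denoised. A small shape sketch with illustrative sizes:

import torch

B, Tp, action_dim = 2, 16, 7            # illustrative: 2 samples, 16-step horizon, 7-DoF actions
noisy_action = torch.randn((B, Tp, action_dim))
print(noisy_action.shape)               # torch.Size([2, 16, 7])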
View File

@@ -243,7 +243,7 @@ class Qwen2VLAConfig(PretrainedConfig):
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
- # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
+ # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":

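Note: the fixed comment documents a backward-compatibility shim. A sketch of the normalization it describes, assuming rope_scaling is a plain dict on the config; the real validator's exact field handling may differ:

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}  # legacy-style config

if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        # `mrope` performs the default RoPE math, so it is recorded as "default"
        rope_scaling["rope_type"] = "default"
    else:
        # e.g. "linear" / "dynamic" pass through as scaled-RoPE variants
        rope_scaling["rope_type"] = rope_scaling["type"]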
View File

@@ -172,7 +172,7 @@ class Qwen2VLRotaryEmbedding(nn.Module):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
- # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
+ # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for spatiotemporal grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
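Note: a shape walk-through of the expansion above, with illustrative sizes; the leading dimension of 3 carries the temporal/height/width position streams:

import torch

bs, positions, half_dim = 2, 10, 64
inv_freq = torch.rand(half_dim)
position_ids = torch.arange(positions).expand(3, bs, positions)        # (3, bs, positions)

inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, bs, -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float()            # (3, bs, 1, positions)
freqs = inv_freq_expanded @ position_ids_expanded                      # (3, bs, half_dim, positions)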
@@ -206,7 +206,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Explanation:
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
- vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately.
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
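Note: the chunked split this docstring describes follows the usual Qwen2-VL pattern: cos/sin carry one copy per t/h/w stream, the channel dimension is cut by mrope_section, and chunk i takes its values from stream i % 3. A hedged sketch with illustrative sizes (the real implementation also repeats mrope_section to span both rotary halves):

import torch

mrope_section = [16, 24, 24]                      # temporal / height / width channels
cos = torch.rand(3, 2, 10, sum(mrope_section))    # (3 streams, bs, seq, dim)

cos_combined = torch.cat(
    [chunk[i % 3] for i, chunk in enumerate(cos.split(mrope_section, dim=-1))],
    dim=-1,
)                                                 # (bs, seq, dim): per-chunk stream selection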
@@ -636,7 +636,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
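Note: a small illustration of the alignment difference the comment warns about, for q_len=2 decoding against k_len=4 cached keys; bottom-right alignment lets the newest query attend to every key:

import torch

q_len, k_len = 2, 4
bottom_right = torch.ones(q_len, k_len).tril(diagonal=k_len - q_len)
top_left = torch.ones(q_len, k_len).tril()
# bottom_right: [[1., 1., 1., 0.],      top_left: [[1., 0., 0., 0.],
#                [1., 1., 1., 1.]]                 [1., 1., 0., 0.]]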
@@ -1066,9 +1066,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
def get_device(self) -> torch.device:
return self.blocks[0].mlp.fc2.weight.device
- def rot_pos_emb(self, grid_thw):
+ def rot_pos_emb(self, grid_spatiotemporal):
pos_ids = []
- for t, h, w in grid_thw:
+ for t, h, w in grid_spatiotemporal:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
@@ -1090,16 +1090,16 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
- max_grid_size = grid_thw[:, 1:].max()
+ max_grid_size = grid_spatiotemporal[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
- def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ def forward(self, hidden_states: torch.Tensor, grid_spatiotemporal: torch.Tensor) -> torch.Tensor:
hidden_states = self.patch_embed(hidden_states)
- rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal)
- cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ cu_seqlens = torch.repeat_interleave(grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]).cumsum(
dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
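Note: a worked example of the cu_seqlens lines above for two images whose (t, h, w) patch grids are (1, 4, 6) and (1, 2, 2); the cumulative boundaries feed variable-length attention:

import torch
import torch.nn.functional as F

grid_spatiotemporal = torch.tensor([[1, 4, 6], [1, 2, 2]])
tokens_per_frame = grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2]   # tensor([24, 4])
cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_spatiotemporal[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)   # tensor([ 0, 24, 28], dtype=torch.int32)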
@@ -1358,7 +1358,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
- The device to plcae the 4D attention mask on.
+ The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@@ -1467,9 +1467,9 @@ QWEN2_VL_INPUTS_DOCSTRING = r"""
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
[`Qwen2VLImageProcessor`] for processing videos.
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
- video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
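Note: a concrete instance of the documented grid shape, assuming Qwen2-VL's default 14-pixel patches: a single 224x224 image yields one (temporal, height, width) row:

import torch

image_grid_spatiotemporal = torch.tensor([[1, 16, 16]])   # 224 / 14 = 16 patches per side
print(image_grid_spatiotemporal.shape)                     # torch.Size([1, 3]) == (num_images, 3)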
@@ -1531,8 +1531,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
def get_rope_index(
self,
input_ids: torch.LongTensor,
- image_grid_thw: Optional[torch.LongTensor] = None,
- video_grid_thw: Optional[torch.LongTensor] = None,
+ image_grid_spatiotemporal: Optional[torch.LongTensor] = None,
+ video_grid_spatiotemporal: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -1541,7 +1541,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
- For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs.
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
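Note: extending the docstring's example to a mixed sequence makes the 3D scheme concrete. For one vision block with (t, h, w) = (2, 2, 2) followed by text, the three id streams enumerate the grid and then advance together, hand-worked below:

# vision temporal ids: [0, 0, 0, 0, 1, 1, 1, 1]   (frame index, repeated h*w times)
# vision height ids:   [0, 0, 1, 1, 0, 0, 1, 1]
# vision width ids:    [0, 1, 0, 1, 0, 1, 0, 1]
# text ids (identical on all three streams): [2, 3, ...]  (resume from max vision id + 1)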
@@ -1565,9 +1565,9 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
- video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -1584,7 +1584,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
- if image_grid_thw is not None or video_grid_thw is not None:
+ if image_grid_spatiotemporal is not None or video_grid_spatiotemporal is not None:
total_input_ids = input_ids
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
@@ -1613,18 +1613,18 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
- image_grid_thw[image_index][0],
- image_grid_thw[image_index][1],
- image_grid_thw[image_index][2],
+ image_grid_spatiotemporal[image_index][0],
+ image_grid_spatiotemporal[image_index][1],
+ image_grid_spatiotemporal[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
- video_grid_thw[video_index][0],
- video_grid_thw[video_index][1],
- video_grid_thw[video_index][2],
+ video_grid_spatiotemporal[video_index][0],
+ video_grid_spatiotemporal[video_index][1],
+ video_grid_spatiotemporal[video_index][2],
)
)
video_index += 1
remain_videos -= 1
@@ -1717,8 +1717,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
- image_grid_thw: Optional[torch.LongTensor] = None,
- video_grid_thw: Optional[torch.LongTensor] = None,
+ image_grid_spatiotemporal: Optional[torch.LongTensor] = None,
+ video_grid_spatiotemporal: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
actions: Optional[torch.LongTensor] = None,
states: Optional[torch.FloatTensor] = None,
@@ -1774,7 +1774,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
actions = actions.to(dtype=self.computed_type, device="cuda")
states = states.to(dtype=self.computed_type, device="cuda")
position_ids, rope_deltas = self.get_rope_index(
- input_ids, image_grid_thw, video_grid_thw, attention_mask
+ input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask
)
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda")
@@ -1790,7 +1790,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ image_embeds = self.visual(pixel_values, grid_spatiotemporal=image_grid_spatiotemporal)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
@@ -1808,7 +1808,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
- video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+ video_embeds = self.visual(pixel_values_videos, grid_spatiotemporal=video_grid_spatiotemporal)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
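Note: the token/feature count checks above guard the merge step that follows in the full file. A sketch of that merge in the usual Qwen2-VL style (not literally shown in this hunk), using the variables from the surrounding forward(): every placeholder position in inputs_embeds is overwritten by one row of the vision embeddings, which is why the two counts must agree:

# variables (input_ids, inputs_embeds, image_embeds) as in the surrounding forward()
image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds.to(inputs_embeds.dtype))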
@@ -1917,7 +1917,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
del inputs_embeds
del labels
del pixel_values
- del image_grid_thw
+ del image_grid_spatiotemporal
del actions
del states
return Qwen2VLCausalLMOutputWithPast(
@@ -1969,8 +1969,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
use_cache=True,
pixel_values=None,
pixel_values_videos=None,
- image_grid_thw=None,
- video_grid_thw=None,
+ image_grid_spatiotemporal=None,
+ video_grid_spatiotemporal=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1988,7 +1988,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
if attention_mask is not None and position_ids is None:
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
position_ids, rope_deltas = self.get_rope_index(
- input_ids, image_grid_thw, video_grid_thw, attention_mask
+ input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask
)
else:
batch_size, seq_length = input_ids.shape
@@ -2040,8 +2040,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
- "image_grid_thw": image_grid_thw,
- "video_grid_thw": video_grid_thw,
+ "image_grid_spatiotemporal": image_grid_spatiotemporal,
+ "video_grid_spatiotemporal": video_grid_spatiotemporal,
"rope_deltas": rope_deltas,
}
)
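Note: with the kwargs renamed here, generation is driven under the new names. A hedged usage sketch, assuming `model` and the batch tensors already exist:

generated_ids = model.generate(
    input_ids,
    attention_mask=attention_mask,
    pixel_values=pixel_values,
    image_grid_spatiotemporal=image_grid_spatiotemporal,   # renamed kwarg threads through here
    max_new_tokens=32,
)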

View File

@@ -97,10 +97,10 @@ class Qwen2VLAProcess:
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
- image_grid_thw = torch.stack([instances["image_grid_thw"] for instances in instances])
+ image_grid_spatiotemporal = torch.stack([instances["image_grid_spatiotemporal"] for instances in instances])
pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
pixel_values_videos = None
- video_grid_thw = None
+ video_grid_spatiotemporal = None
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
labels = torch.flip(labels, dims=[1])
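Note: the flip/pad/flip pattern above exists because pad_sequence only right-pads; reversing each sequence, padding, then reversing the batch again yields the left padding a decoder-only model expects. A self-contained example:

import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
flipped = [torch.flip(s, dims=[0]) for s in seqs]
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=0)
left_padded = torch.flip(padded, dims=[1])
print(left_padded)   # tensor([[1, 2, 3], [0, 4, 5]])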
@@ -110,7 +110,7 @@ class Qwen2VLAProcess:
input_ids = torch.flip(input_ids, dims=[1])
b = input_ids.shape[0]
- image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2])
+ image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2])
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
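Note: a shape sketch of the reshape above, with illustrative sizes: per-sample grids stacked to (batch, num_images, 3) are flattened to (batch * num_images, 3), the layout the vision tower consumes:

import torch

b, n = 2, 3
image_grid_spatiotemporal = torch.ones(b, n, 3, dtype=torch.long)   # (batch, num_images, thw)
flat = image_grid_spatiotemporal.reshape(b * n, 3)
print(flat.shape)   # torch.Size([6, 3])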
@@ -119,9 +119,9 @@ class Qwen2VLAProcess:
input_ids=input_ids,
attention_mask=attention_mask[0],
labels=labels,
- image_grid_thw=image_grid_thw,
+ image_grid_spatiotemporal=image_grid_spatiotemporal,
pixel_values_videos=pixel_values_videos,
- video_grid_thw=video_grid_thw,
+ video_grid_spatiotemporal=video_grid_spatiotemporal,
pixel_values=pixel_values,
)
return batch