cyf_update

parent f7d664dcc0
commit 9f4d490423
@@ -124,7 +124,7 @@ class DexVLAPolicy(PreTrainedPolicy):
         is_eval=True,
         pixel_values=None,
         attention_mask=None,
-        image_grid_thw=None,
+        image_grid_spatiotemporal=None,
     ):
         input_ids = input_ids.to("cuda")
         with torch.inference_mode():
@@ -132,7 +132,7 @@ class DexVLAPolicy(PreTrainedPolicy):
             input_ids,
             pixel_values=pixel_values,
             attention_mask=attention_mask,
-            image_grid_thw=image_grid_thw,
+            image_grid_spatiotemporal=image_grid_spatiotemporal,
             is_eval=is_eval,
             num_beams=1,
             do_sample=False,
@@ -180,7 +180,7 @@ class DexVLAPolicy(PreTrainedPolicy):
         is_eval=True,
         pixel_values=None,
         attention_mask=None,
-        image_grid_thw=None,
+        image_grid_spatiotemporal=None,
     ):
         input_ids = input_ids.to("cuda")
         with torch.inference_mode():
@@ -188,7 +188,7 @@ class DexVLAPolicy(PreTrainedPolicy):
             input_ids,
             pixel_values=pixel_values,
             attention_mask=attention_mask,
-            image_grid_thw=image_grid_thw,
+            image_grid_spatiotemporal=image_grid_spatiotemporal,
             is_eval=is_eval,
             tinyvla=True,
         )
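
The four DexVLAPolicy hunks above only rename the grid keyword, so every caller now has to pass image_grid_spatiotemporal. A hedged sketch of such a call site; the method name `evaluate` and the batch keys are illustrative, not taken from this diff:

import torch

def run_policy_eval(policy, batch):
    # illustrative invocation mirroring the renamed keyword above
    return policy.evaluate(
        batch["input_ids"].to("cuda"),
        pixel_values=batch["pixel_values"],
        attention_mask=batch["attention_mask"],
        image_grid_spatiotemporal=batch["image_grid_spatiotemporal"],
        is_eval=True,
    )
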
@@ -433,7 +433,7 @@ class ScaleDP(PreTrainedModel):
         Tp = self.num_queries
         action_dim = self.action_dim

-        # initialize action from Guassian noise
+        # initialize action from Gaussian noise
         noisy_action = torch.randn((B, Tp, action_dim)).cuda()

         naction = noisy_action.to(dtype=hidden_states.dtype)
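
For context on the ScaleDP hunk: this is a diffusion action head, so sampling starts from pure Gaussian noise over the whole action chunk of shape (B, Tp, action_dim) and is refined step by step. A minimal pure-torch sketch of that loop, with a stand-in predictor in place of the real model and scheduler:

import torch

B, Tp, action_dim, num_inference_steps = 2, 16, 14, 10
noisy_action = torch.randn((B, Tp, action_dim))  # initialize action from Gaussian noise

def predict_noise(action, timestep):
    # stand-in for ScaleDP's forward pass (conditioned on hidden_states in the real code)
    return torch.zeros_like(action)

for t in reversed(range(num_inference_steps)):
    eps = predict_noise(noisy_action, t)
    # simplified Euler-style update; the real code delegates to its noise scheduler
    noisy_action = noisy_action - eps / num_inference_steps
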
@@ -243,7 +243,7 @@ class Qwen2VLAConfig(PretrainedConfig):

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
-       # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
+       # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
@@ -172,7 +172,7 @@ class Qwen2VLRotaryEmbedding(nn.Module):
         if "dynamic" in self.rope_type:
             self._dynamic_frequency_update(position_ids, device=x.device)

-        # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
+        # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for spatiotemporal grids
         # So we expand the inv_freq to shape (3, ...)
         inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
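
The shape bookkeeping in this hunk is easiest to see with toy tensors: the inverse frequencies are broadcast across a leading axis of size 3 so each of the temporal/height/width position tracks gets its own angle table. A minimal sketch with made-up values:

import torch

inv_freq = torch.tensor([1.0, 0.1, 0.01, 0.001])  # stand-in inverse frequencies, dim 4
position_ids = torch.zeros(3, 2, 7)               # (3 thw axes, batch, seq_len)
inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
print(inv_freq_expanded.shape, position_ids_expanded.shape, freqs.shape)
# torch.Size([3, 2, 4, 1]) torch.Size([3, 2, 1, 7]) torch.Size([3, 2, 7, 4])
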
@@ -206,7 +206,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
     Explanation:
         Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
         sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
-        vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately.
+        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
         Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
         For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
         height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
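
The channel split this docstring describes can be reproduced standalone: mrope_section (here Qwen2-VL's published default [16, 24, 24]) says how many channels each of the temporal/height/width tables contributes, and the chunks are interleaved across the leading axis of size 3. Toy shapes only:

import torch

head_dim = 128
mrope_section = [16, 24, 24]           # t/h/w channel budget; sums to head_dim // 2
cos = torch.randn(3, 1, 10, head_dim)  # (3 axes, batch, seq_len, head_dim)

sections = mrope_section * 2           # [16, 24, 24, 16, 24, 24] covers all 128 channels
chunks = cos.split(sections, dim=-1)
cos_mixed = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
print(cos_mixed.shape)                 # torch.Size([1, 10, 128])
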
@@ -636,7 +636,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1066,9 +1066,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
     def get_device(self) -> torch.device:
         return self.blocks[0].mlp.fc2.weight.device

-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(self, grid_spatiotemporal):
         pos_ids = []
-        for t, h, w in grid_thw:
+        for t, h, w in grid_spatiotemporal:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
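
The reshape this hunk truncates continues with the spatial-merge permutation known from upstream Qwen2-VL; a toy standalone version (assuming spatial_merge_size = 2, which is upstream's default) shows what the index juggling produces:

import torch

h, w, merge = 4, 4, 2
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = (
    hpos_ids.reshape(h // merge, merge, w // merge, merge)
    .permute(0, 2, 1, 3)
    .flatten()
)
print(hpos_ids.reshape(-1, merge * merge))
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [2, 2, 3, 3],
#         [2, 2, 3, 3]])  -- each row holds the row-indices of one 2x2 merged patch
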
@@ -1090,16 +1090,16 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
             wpos_ids = wpos_ids.flatten()
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
+        max_grid_size = grid_spatiotemporal[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_spatiotemporal: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal)

-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+        cu_seqlens = torch.repeat_interleave(grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]).cumsum(
             dim=0, dtype=torch.int32
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
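
The cu_seqlens line is the densest change here; with concrete grid rows it is easy to check by hand. Each (t, h, w) row contributes t windows of h*w patches, and the cumulative sums, left-padded with 0, are the sequence boundaries handed to the attention kernel. Worked example with made-up grids:

import torch
import torch.nn.functional as F

grid_spatiotemporal = torch.tensor([[1, 4, 4], [2, 2, 2]])  # one image, one 2-frame video
cu_seqlens = torch.repeat_interleave(
    grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2],  # h*w patches per frame
    grid_spatiotemporal[:, 0],                              # repeated t times
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 16, 20, 24], dtype=torch.int32)
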
@@ -1358,7 +1358,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
-               The device to plcae the 4D attention mask on.
+               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -1467,9 +1467,9 @@ QWEN2_VL_INPUTS_DOCSTRING = r"""
            The tensors corresponding to the input videos. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
            [`Qwen2VLImageProcessor`] for processing videos.
-        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
-        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
@@ -1531,8 +1531,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
     def get_rope_index(
         self,
         input_ids: torch.LongTensor,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
+        image_grid_spatiotemporal: Optional[torch.LongTensor] = None,
+        video_grid_spatiotemporal: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -1541,7 +1541,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
        Explanation:
            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

-           For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs.
+           For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
            Examples:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
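
The docstring's text example generalizes as follows: for pure text the three position-id rows are identical, so mrope degenerates to ordinary 1D RoPE, while a vision span indexes each axis separately. A toy reconstruction, not the repository's exact code:

import torch

text_len = 5
text_pos = torch.arange(text_len).unsqueeze(0).expand(3, -1)
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 2, 3, 4],
#         [0, 1, 2, 3, 4]])  -- identical rows, i.e. plain 1D RoPE

t, h, w = 1, 2, 2  # a tiny image grid after merging
t_idx = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
h_idx = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
w_idx = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
vision_pos = torch.stack([t_idx, h_idx, w_idx])
# tensor([[0, 0, 0, 0],
#         [0, 0, 1, 1],
#         [0, 1, 0, 1]])  -- each axis indexes its own dimension
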
@@ -1565,9 +1565,9 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
-           image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+           image_grid_spatiotemporal (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
-           video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+           video_grid_spatiotemporal (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -1584,7 +1584,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
-       if image_grid_thw is not None or video_grid_thw is not None:
+       if image_grid_spatiotemporal is not None or video_grid_spatiotemporal is not None:
            total_input_ids = input_ids
            position_ids = torch.ones(
                3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
@@ -1613,18 +1613,18 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
-                           image_grid_thw[image_index][0],
-                           image_grid_thw[image_index][1],
-                           image_grid_thw[image_index][2],
+                           image_grid_spatiotemporal[image_index][0],
+                           image_grid_spatiotemporal[image_index][1],
+                           image_grid_spatiotemporal[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
-                           video_grid_thw[video_index][0],
-                           video_grid_thw[video_index][1],
-                           video_grid_thw[video_index][2],
+                           video_grid_spatiotemporal[video_index][0],
+                           video_grid_spatiotemporal[video_index][1],
+                           video_grid_spatiotemporal[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
@@ -1717,8 +1717,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
-       image_grid_thw: Optional[torch.LongTensor] = None,
-       video_grid_thw: Optional[torch.LongTensor] = None,
+       image_grid_spatiotemporal: Optional[torch.LongTensor] = None,
+       video_grid_spatiotemporal: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        actions: Optional[torch.LongTensor] = None,
        states: Optional[torch.FloatTensor] = None,
@@ -1774,7 +1774,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
            actions = actions.to(dtype=self.computed_type, device="cuda")
            states = states.to(dtype=self.computed_type, device="cuda")
            position_ids, rope_deltas = self.get_rope_index(
-               input_ids, image_grid_thw, video_grid_thw, attention_mask
+               input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask
            )
        if pixel_values is not None:
            pixel_values = pixel_values.to(dtype=self.computed_type, device="cuda")
@@ -1790,7 +1790,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
-               image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+               image_embeds = self.visual(pixel_values, grid_spatiotemporal=image_grid_spatiotemporal)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
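
The guard in this hunk compares the number of image-placeholder tokens against the number of rows the vision tower returned; if the processor and model disagree about patch merging, this is where it surfaces. Illustrative check, with an assumed placeholder id:

import torch

image_token_id = 151655  # assumed placeholder id, for illustration only
input_ids = torch.tensor([[1, image_token_id, image_token_id, 2]])
image_embeds = torch.randn(2, 8)  # 2 vision features of hidden size 8

n_image_tokens = (input_ids == image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
assert n_image_tokens == n_image_features, (
    f"Image features and image tokens do not match: {n_image_tokens} vs {n_image_features}"
)
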
@@ -1808,7 +1808,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-               video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+               video_embeds = self.visual(pixel_values_videos, grid_spatiotemporal=video_grid_spatiotemporal)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
@@ -1917,7 +1917,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
        del inputs_embeds
        del labels
        del pixel_values
-       del image_grid_thw
+       del image_grid_spatiotemporal
        del actions
        del states
        return Qwen2VLCausalLMOutputWithPast(
@@ -1969,8 +1969,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
-       image_grid_thw=None,
-       video_grid_thw=None,
+       image_grid_spatiotemporal=None,
+       video_grid_spatiotemporal=None,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
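
The comment closing this hunk refers to the standard KV-cache trick: once a prefix is cached, only the tokens at `cache_position` are fed forward. A minimal sketch of that slice:

import torch

input_ids = torch.tensor([[11, 12, 13, 14]])
cache_position = torch.tensor([3])  # everything before index 3 is already cached
input_ids = input_ids[:, cache_position]
print(input_ids)  # tensor([[14]]) -- only the unprocessed token is fed to the model
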
@@ -1988,7 +1988,7 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
        if attention_mask is not None and position_ids is None:
            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
                position_ids, rope_deltas = self.get_rope_index(
-                   input_ids, image_grid_thw, video_grid_thw, attention_mask
+                   input_ids, image_grid_spatiotemporal, video_grid_spatiotemporal, attention_mask
                )
            else:
                batch_size, seq_length = input_ids.shape
@@ -2040,8 +2040,8 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "pixel_values_videos": pixel_values_videos,
-               "image_grid_thw": image_grid_thw,
-               "video_grid_thw": video_grid_thw,
+               "image_grid_spatiotemporal": image_grid_spatiotemporal,
+               "video_grid_spatiotemporal": video_grid_spatiotemporal,
                "rope_deltas": rope_deltas,
            }
        )
@@ -97,10 +97,10 @@ class Qwen2VLAProcess:
        input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
        labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]

-       image_grid_thw = torch.stack([instances["image_grid_thw"] for instances in instances])
+       image_grid_spatiotemporal = torch.stack([instances["image_grid_spatiotemporal"] for instances in instances])
        pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
        pixel_values_videos = None
-       video_grid_thw = None
+       video_grid_spatiotemporal = None

        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        labels = torch.flip(labels, dims=[1])
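
Aside from the rename, this collator relies on a flip/pad/flip trick worth spelling out: pad_sequence pads only on the right, so reversing each sequence before padding and reversing the batch afterwards yields left padding, which decoder-only generation expects. Demonstration:

import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
flipped = [torch.flip(s, dims=[0]) for s in seqs]
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=0)
left_padded = torch.flip(padded, dims=[1])
print(left_padded)
# tensor([[1, 2, 3],
#         [0, 4, 5]])  -- the padding ends up on the left
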
@@ -110,7 +110,7 @@ class Qwen2VLAProcess:
        input_ids = torch.flip(input_ids, dims=[1])
        b = input_ids.shape[0]

-       image_grid_thw = image_grid_thw.reshape(b * image_grid_thw.shape[1], image_grid_thw.shape[2])
+       image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2])
        pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])

        attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
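
The reshape in this hunk flattens the per-sample stacking done earlier: the collator stacks grids to (batch, num_images, 3), while the model wants one (t, h, w) row per image across the whole batch. Toy shapes:

import torch

b, num_images = 2, 2
image_grid_spatiotemporal = torch.tensor(
    [[[1, 4, 4], [1, 2, 2]],
     [[1, 4, 4], [1, 2, 2]]]
)  # (2, 2, 3)
flat = image_grid_spatiotemporal.reshape(
    b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
)
print(flat.shape)  # torch.Size([4, 3])
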
@@ -119,9 +119,9 @@ class Qwen2VLAProcess:
            input_ids=input_ids,
            attention_mask=attention_mask[0],
            labels=labels,
-           image_grid_thw=image_grid_thw,
+           image_grid_spatiotemporal=image_grid_spatiotemporal,
            pixel_values_videos=pixel_values_videos,
-           video_grid_thw=video_grid_thw,
+           video_grid_spatiotemporal=video_grid_spatiotemporal,
            pixel_values=pixel_values,
        )
        return batch