Merge branch '2025_02_20_add_dexvla' of https://github.com/JayceWen/lerobot into 2025_02_20_add_dexvla
This commit is contained in:
commit
b83cb0ba89
|
@ -1099,9 +1099,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||||
hidden_states = self.patch_embed(hidden_states)
|
hidden_states = self.patch_embed(hidden_states)
|
||||||
rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal)
|
rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal)
|
||||||
|
|
||||||
cu_seqlens = torch.repeat_interleave(grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]).cumsum(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
dim=0, dtype=torch.int32
|
grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]
|
||||||
)
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
|
||||||
for blk in self.blocks:
|
for blk in self.blocks:
|
||||||
|
|
|
@ -97,7 +97,9 @@ class Qwen2VLAProcess:
|
||||||
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
|
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
|
||||||
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
|
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
|
||||||
|
|
||||||
image_grid_spatiotemporal = torch.stack([instances["image_grid_spatiotemporal"] for instances in instances])
|
image_grid_spatiotemporal = torch.stack(
|
||||||
|
[instances["image_grid_spatiotemporal"] for instances in instances]
|
||||||
|
)
|
||||||
pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
|
pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
|
||||||
pixel_values_videos = None
|
pixel_values_videos = None
|
||||||
video_grid_spatiotemporal = None
|
video_grid_spatiotemporal = None
|
||||||
|
@ -110,7 +112,9 @@ class Qwen2VLAProcess:
|
||||||
input_ids = torch.flip(input_ids, dims=[1])
|
input_ids = torch.flip(input_ids, dims=[1])
|
||||||
b = input_ids.shape[0]
|
b = input_ids.shape[0]
|
||||||
|
|
||||||
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2])
|
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(
|
||||||
|
b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
|
||||||
|
)
|
||||||
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
|
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
|
||||||
|
|
||||||
attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
|
attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
|
||||||
|
|
Loading…
Reference in New Issue