Merge branch '2025_02_20_add_dexvla' of https://github.com/JayceWen/lerobot into 2025_02_20_add_dexvla

This commit is contained in:
wk 2025-03-11 14:02:35 +08:00
commit b83cb0ba89
2 changed files with 9 additions and 5 deletions

View File

@@ -1099,9 +1099,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal)
cu_seqlens = torch.repeat_interleave(grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]).cumsum(
dim=0, dtype=torch.int32
)
cu_seqlens = torch.repeat_interleave(
grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks:

View File

@@ -97,7 +97,9 @@ class Qwen2VLAProcess:
input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
image_grid_spatiotemporal = torch.stack([instances["image_grid_spatiotemporal"] for instances in instances])
image_grid_spatiotemporal = torch.stack(
[instances["image_grid_spatiotemporal"] for instances in instances]
)
pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
pixel_values_videos = None
video_grid_spatiotemporal = None
@@ -110,7 +112,9 @@ class Qwen2VLAProcess:
input_ids = torch.flip(input_ids, dims=[1])
b = input_ids.shape[0]
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2])
image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(
b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
)
pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)