diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
index aef41381..b96c1f3d 100644
--- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
+++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
@@ -1099,9 +1099,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         hidden_states = self.patch_embed(hidden_states)
         rotary_pos_emb = self.rot_pos_emb(grid_spatiotemporal)
 
-        cu_seqlens = torch.repeat_interleave(grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]).cumsum(
-            dim=0, dtype=torch.int32
-        )
+        cu_seqlens = torch.repeat_interleave(
+            grid_spatiotemporal[:, 1] * grid_spatiotemporal[:, 2], grid_spatiotemporal[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
 
         for blk in self.blocks:
diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py
index 81988998..4b75c439 100644
--- a/lerobot/common/policies/dexvla/robot_data_processor.py
+++ b/lerobot/common/policies/dexvla/robot_data_processor.py
@@ -97,7 +97,9 @@ class Qwen2VLAProcess:
         input_ids = [torch.flip(instance["input_ids"].squeeze(0), dims=[0]) for instance in instances]
         labels = [torch.flip(instance["labels"].squeeze(0), dims=[0]) for instance in instances]
 
-        image_grid_spatiotemporal = torch.stack([instances["image_grid_spatiotemporal"] for instances in instances])
+        image_grid_spatiotemporal = torch.stack(
+            [instances["image_grid_spatiotemporal"] for instances in instances]
+        )
         pixel_values = torch.stack([instances["pixel_values"] for instances in instances])
         pixel_values_videos = None
         video_grid_spatiotemporal = None
@@ -110,7 +112,9 @@ class Qwen2VLAProcess:
         input_ids = torch.flip(input_ids, dims=[1])
         b = input_ids.shape[0]
 
-        image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2])
+        image_grid_spatiotemporal = image_grid_spatiotemporal.reshape(
+            b * image_grid_spatiotemporal.shape[1], image_grid_spatiotemporal.shape[2]
+        )
         pixel_values = pixel_values.reshape(b * pixel_values.shape[1], pixel_values.shape[2])
 
         attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)