From ca3027cec20b309adb7d717610597d5633b4fd55 Mon Sep 17 00:00:00 2001
From: wk
Date: Tue, 11 Mar 2025 14:39:21 +0800
Subject: [PATCH] fix_processor

---
 .../dexvla/qwe2_vla/modeling_qwen2_vla.py    |  3 +-
 .../policies/dexvla/robot_data_processor.py  | 35 +++++++++----------
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
index efbc3c3a..da47cd1b 100644
--- a/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
+++ b/lerobot/common/policies/dexvla/qwe2_vla/modeling_qwen2_vla.py
@@ -49,6 +49,7 @@ from transformers.utils import (
 )
 
 from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM
+from transformers import AutoModelForCausalLM
 
 from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
 
@@ -2047,6 +2048,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
 
         return model_inputs
 
-from transformers import AutoModelForCausalLM
+
 
 AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py
index 4b75c439..db350165 100644
--- a/lerobot/common/policies/dexvla/robot_data_processor.py
+++ b/lerobot/common/policies/dexvla/robot_data_processor.py
@@ -30,15 +30,16 @@ class Qwen2VLAProcess:
         len_views = images.shape[0]
         messages = self.construct_chat_data(len_views, raw_lang)
 
-        data_dict = dict(
-            messages=messages,
-        )
+        data_dict = {
+            "messages": messages
+        }
+
 
         image_data = torch.chunk(images, len_views, 0)
 
         images_list = []
 
-        for i, each in enumerate(image_data):
+        for _i, each in enumerate(image_data):
             img_pil = self.qwen2_image_preprocess(each)
             images_list.append(img_pil)
 
@@ -58,10 +59,7 @@
             return model_inputs
 
         input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
-        if use_reasoning:
-            answer = reasoning + "Next action:" + "<|im_end|>"
-        else:
-            answer = "" + "<|im_end|>"
+        answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
 
         output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
         output_labels = output_text["input_ids"]
@@ -119,15 +117,16 @@
 
         attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
 
-        batch = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask[0],
-            labels=labels,
-            image_grid_spatiotemporal=image_grid_spatiotemporal,
-            pixel_values_videos=pixel_values_videos,
-            video_grid_spatiotemporal=video_grid_spatiotemporal,
-            pixel_values=pixel_values,
-        )
+        batch = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask[0],
+            "labels": labels,
+            "image_grid_spatiotemporal": image_grid_spatiotemporal,
+            "pixel_values_videos": pixel_values_videos,
+            "video_grid_spatiotemporal": video_grid_spatiotemporal,
+            "pixel_values": pixel_values,
+        }
+
         return batch
 
     def construct_chat_data(self, len_image, raw_lang):
@@ -138,7 +137,7 @@
             },
         ]
 
-        for i in range(len_image):
+        for _i in range(len_image):
             messages[0]["content"].append(
                 {
                     "type": "image",