fix_processor

wk committed 2025-03-11 14:39:21 +08:00
parent a13072f72f
commit ca3027cec2
2 changed files with 19 additions and 19 deletions

File 1 of 2

@@ -49,6 +49,7 @@ from transformers.utils import (
 )
 from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM
+from transformers import AutoModelForCausalLM
 from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
@@ -2047,6 +2048,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
         return model_inputs
-from transformers import AutoModelForCausalLM
 AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
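
The only functional change in this file is hoisting `from transformers import AutoModelForCausalLM` from the bottom of the module into the top-level import block; the `AutoModelForCausalLM.register(...)` call itself is unchanged. That registration is what lets the generic auto class resolve this custom architecture later. A minimal sketch of how the registration is typically consumed (the checkpoint path is a placeholder, not from this repo):

    from transformers import AutoModelForCausalLM

    # Once the Qwen2VLAConfig -> Qwen2VLForConditionalGenerationForVLA pair is
    # registered, the auto class can instantiate the custom model from any
    # checkpoint whose config resolves to Qwen2VLAConfig.
    model = AutoModelForCausalLM.from_pretrained("path/to/dexvla-checkpoint")  # placeholder path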

File 2 of 2

@@ -30,15 +30,16 @@ class Qwen2VLAProcess:
         len_views = images.shape[0]
         messages = self.construct_chat_data(len_views, raw_lang)
-        data_dict = dict(
-            messages=messages,
-        )
+        data_dict = {
+            "messages": messages
+        }
         image_data = torch.chunk(images, len_views, 0)
         images_list = []
-        for i, each in enumerate(image_data):
+        for _i, each in enumerate(image_data):
             img_pil = self.qwen2_image_preprocess(each)
             images_list.append(img_pil)
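
The `dict(...)` → `{...}` and `i` → `_i` changes in this hunk (and in the hunks below) read as lint-driven cleanups: a dict literal is the idiomatic spelling when all keys are known strings, and an underscore prefix marks a loop variable as intentionally unused. The two dict forms are equivalent at runtime; a quick sanity check, assuming nothing beyond the standard library:

    # dict(messages=msgs) and {"messages": msgs} build the same mapping;
    # the literal form also permits keys that are not valid identifiers.
    msgs = ["hello"]
    assert dict(messages=msgs) == {"messages": msgs}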
@@ -58,10 +59,7 @@ class Qwen2VLAProcess:
             return model_inputs
         input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
-        if use_reasoning:
-            answer = reasoning + "Next action:" + "<|im_end|>"
-        else:
-            answer = "" + "<|im_end|>"
+        answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
         output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
         output_labels = output_text["input_ids"]
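
Collapsing the if/else into a conditional expression preserves behavior: the conditional operator binds more loosely than `+`, so each branch keeps its full concatenation. A small check of the grouping, with placeholder values:

    use_reasoning, reasoning = True, "pick up the cup. "
    answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
    # parsed as: (reasoning + "Next action:" + "<|im_end|>") if use_reasoning else ("" + "<|im_end|>")
    assert answer == reasoning + "Next action:" + "<|im_end|>"

As an aside on the surrounding context: the `-100` fill value for `input_labels` is the default `ignore_index` of PyTorch's cross-entropy loss, so those prompt positions contribute nothing to the training loss.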
@@ -119,15 +117,16 @@ class Qwen2VLAProcess:
         attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
-        batch = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask[0],
-            labels=labels,
-            image_grid_spatiotemporal=image_grid_spatiotemporal,
-            pixel_values_videos=pixel_values_videos,
-            video_grid_spatiotemporal=video_grid_spatiotemporal,
-            pixel_values=pixel_values,
-        )
+        batch = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask[0],
+            "labels": labels,
+            "image_grid_spatiotemporal": image_grid_spatiotemporal,
+            "pixel_values_videos": pixel_values_videos,
+            "video_grid_spatiotemporal": video_grid_spatiotemporal,
+            "pixel_values": pixel_values,
+        }
         return batch

     def construct_chat_data(self, len_image, raw_lang):
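
Context for the unchanged `attention_mask` line above: `input_ids.ne(pad_token_id)` yields a boolean mask that is True on real tokens and False on padding, which is the form `transformers` accepts for `attention_mask`. A small illustration, assuming a pad id of 0:

    import torch

    pad_id = 0
    input_ids = torch.tensor([[5, 8, 2, pad_id, pad_id]])
    attention_mask = input_ids.ne(pad_id)
    # tensor([[ True,  True,  True, False, False]])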
@@ -138,7 +137,7 @@ class Qwen2VLAProcess:
             },
         ]
-        for i in range(len_image):
+        for _i in range(len_image):
             messages[0]["content"].append(
                 {
                     "type": "image",