fix_processor

parent a13072f72f
commit ca3027cec2
@@ -49,6 +49,7 @@ from transformers.utils import (
 )
 
 from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM
+from transformers import AutoModelForCausalLM
 
 from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
 
@@ -2047,6 +2048,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMi
         return model_inputs
 
 
-from transformers import AutoModelForCausalLM
+
 
 AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
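For context, registering the custom classes with the Auto API is what lets AutoModelForCausalLM.from_pretrained / from_config resolve to Qwen2VLForConditionalGenerationForVLA. A minimal sketch of the registration pattern, with hypothetical class names standing in for the real ones:

from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

class MyVLAConfig(PretrainedConfig):
    # model_type must be unique so AutoConfig can route checkpoints here
    model_type = "my_vla"

class MyVLAModel(PreTrainedModel):
    config_class = MyVLAConfig

    def __init__(self, config):
        super().__init__(config)
        # ... build the actual modules here ...

# Map the model_type string to the config class, then the config class to
# the model class; afterwards AutoModelForCausalLM.from_config(MyVLAConfig())
# returns a MyVLAModel instance.
AutoConfig.register("my_vla", MyVLAConfig)
AutoModelForCausalLM.register(MyVLAConfig, MyVLAModel)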
@@ -30,15 +30,16 @@ class Qwen2VLAProcess:
         len_views = images.shape[0]
         messages = self.construct_chat_data(len_views, raw_lang)
 
-        data_dict = dict(
-            messages=messages,
-        )
+        data_dict = {
+            "messages":messages
+        }
 
         image_data = torch.chunk(images, len_views, 0)
 
         images_list = []
 
-        for i, each in enumerate(image_data):
+        for _i, each in enumerate(image_data):
             img_pil = self.qwen2_image_preprocess(each)
             images_list.append(img_pil)
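As an aside, torch.chunk(images, len_views, 0) splits the stacked camera views back into one tensor per view along dim 0, which is what feeds the per-view PIL conversion above. A quick illustration with made-up shapes:

import torch

# Suppose three camera views were stacked along dim 0: (views, C, H, W).
images = torch.randn(3, 3, 224, 224)
len_views = images.shape[0]

# chunk returns a tuple of len_views tensors, each (1, 3, 224, 224);
# the processor converts each chunk to a PIL image in turn.
image_data = torch.chunk(images, len_views, 0)
for each in image_data:
    print(each.shape)  # torch.Size([1, 3, 224, 224])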
@@ -58,10 +59,7 @@ class Qwen2VLAProcess:
             return model_inputs
 
         input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
-        if use_reasoning:
-            answer = reasoning + "Next action:" + "<|im_end|>"
-        else:
-            answer = "" + "<|im_end|>"
+        answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
 
         output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
         output_labels = output_text["input_ids"]
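For readers unfamiliar with the -100 fill: it is the default ignore_index of PyTorch's cross-entropy loss, so prompt positions labeled this way contribute nothing to the language-modeling loss and only the answer tokens are supervised. A minimal sketch with hypothetical tensors:

import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(1, 5, vocab_size)          # (batch, seq, vocab)
labels = torch.tensor([[-100, -100, 3, 1, 2]])  # prompt tokens masked out

# cross_entropy skips positions equal to ignore_index (-100 by default),
# so only the last three tokens are scored.
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
print(loss)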
@@ -119,15 +117,16 @@ class Qwen2VLAProcess:
 
         attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
 
-        batch = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask[0],
-            labels=labels,
-            image_grid_spatiotemporal=image_grid_spatiotemporal,
-            pixel_values_videos=pixel_values_videos,
-            video_grid_spatiotemporal=video_grid_spatiotemporal,
-            pixel_values=pixel_values,
-        )
+        batch = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask[0],
+            "labels": labels,
+            "image_grid_spatiotemporal": image_grid_spatiotemporal,
+            "pixel_values_videos": pixel_values_videos,
+            "video_grid_spatiotemporal": video_grid_spatiotemporal,
+            "pixel_values": pixel_values,
+        }
 
         return batch
 
     def construct_chat_data(self, len_image, raw_lang):
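The attention mask is derived straight from padding: input_ids.ne(pad_token_id) is a boolean tensor that is False exactly at pad positions (the trailing comma makes attention_mask a one-element tuple, hence the attention_mask[0] indexing in the batch). A small illustration, assuming a hypothetical pad_token_id of 0:

import torch

pad_token_id = 0
input_ids = torch.tensor([[101, 7, 42, 0, 0]])  # last two positions are padding

# ne() marks every non-pad position True; cast to long for a 0/1 mask.
attention_mask = input_ids.ne(pad_token_id).long()
print(attention_mask)  # tensor([[1, 1, 1, 0, 0]])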
@@ -138,7 +137,7 @@ class Qwen2VLAProcess:
             },
         ]
 
-        for i in range(len_image):
+        for _i in range(len_image):
             messages[0]["content"].append(
                 {
                     "type": "image",
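This loop builds the Qwen2-VL chat payload: one {"type": "image"} content entry per camera view, which the processor later pairs with the actual frames, plus the language instruction. A minimal sketch of the structure being assembled; the fields beyond those visible in the diff are illustrative:

def construct_chat_data(len_image: int, raw_lang: str) -> list:
    # One user turn whose content interleaves image placeholders and text.
    messages = [{"role": "user", "content": []}]
    for _i in range(len_image):
        messages[0]["content"].append({"type": "image"})
    messages[0]["content"].append({"type": "text", "text": raw_lang})
    return messages

print(construct_chat_data(2, "Pick up the cup."))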