fix_processor

parent a13072f72f
commit ca3027cec2
@@ -49,6 +49,7 @@ from transformers.utils import (
 )
 
 from lerobot.common.policies.dexvla.fusion_modules import ActionProjector,FiLM
+from transformers import AutoModelForCausalLM
 
 from .configuration_qwen2_vla import Qwen2VLAConfig, Qwen2VLVisionConfig
 
@@ -2047,6 +2048,6 @@ class Qwen2VLForConditionalGenerationForVLA(Qwen2VLPreTrainedModel, GenerationMixin):
         return model_inputs
 
 
-from transformers import AutoModelForCausalLM
 
 AutoModelForCausalLM.register(Qwen2VLAConfig, Qwen2VLForConditionalGenerationForVLA)
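Moving the module-level import to the top of the file keeps the file lint-clean; the Auto-API registration itself is unchanged and is what lets the custom VLA head be built through the generic transformers factory. A minimal sketch of that usage, assuming Qwen2VLAConfig can be imported from the package's configuration_qwen2_vla module and constructed with its defaults (the import path below is an assumption, the diff only shows the relative import):

    # Minimal sketch: after AutoModelForCausalLM.register(Qwen2VLAConfig, ...),
    # the Auto factory resolves the custom class from the config's type.
    from transformers import AutoModelForCausalLM

    # Path assumed for illustration; the diff only shows
    # "from .configuration_qwen2_vla import Qwen2VLAConfig".
    from lerobot.common.policies.dexvla.configuration_qwen2_vla import Qwen2VLAConfig

    config = Qwen2VLAConfig()  # default hyperparameters, illustration only
    model = AutoModelForCausalLM.from_config(config)
    assert type(model).__name__ == "Qwen2VLForConditionalGenerationForVLA"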
@@ -30,15 +30,16 @@ class Qwen2VLAProcess:
         len_views = images.shape[0]
         messages = self.construct_chat_data(len_views, raw_lang)
 
-        data_dict = dict(
-            messages=messages,
-        )
+        data_dict = {
+            "messages": messages,
+        }
 
         image_data = torch.chunk(images, len_views, 0)
 
         images_list = []
 
-        for i, each in enumerate(image_data):
+        for _i, each in enumerate(image_data):
             img_pil = self.qwen2_image_preprocess(each)
             images_list.append(img_pil)
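The loop above splits the stacked camera views with torch.chunk and preprocesses each view separately; a small sketch of what that chunking produces, using an illustrative tensor shape:

    import torch

    # Illustration of the torch.chunk call above: a (num_views, C, H, W) stack
    # becomes num_views tensors of shape (1, C, H, W), one per camera view.
    images = torch.rand(3, 3, 224, 224)  # 3 views; shape chosen only for the example
    image_data = torch.chunk(images, images.shape[0], 0)
    assert len(image_data) == 3
    assert image_data[0].shape == (1, 3, 224, 224)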
@@ -58,10 +59,7 @@ class Qwen2VLAProcess:
             return model_inputs
 
         input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
-        if use_reasoning:
-            answer = reasoning + "Next action:" + "<|im_end|>"
-        else:
-            answer = "" + "<|im_end|>"
+        answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
 
         output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
         output_labels = output_text["input_ids"]
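input_labels is filled with -100 so the prompt tokens contribute nothing to the loss, while the tokenized answer (the reasoning text plus "Next action:" and the <|im_end|> terminator when use_reasoning is set, otherwise just <|im_end|>) supplies the supervised targets. A short sketch of the -100 masking convention, with toy tensor sizes:

    import torch

    # Positions labeled -100 are skipped by cross-entropy; ignore_index=-100 is
    # also the default used by Hugging Face causal-LM losses.
    logits = torch.randn(5, 32)                    # 5 token positions, toy vocab of 32
    labels = torch.tensor([-100, -100, 7, 3, 11])  # prompt masked, answer supervised
    loss = torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100)
    print(loss.item())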
@@ -119,15 +117,16 @@ class Qwen2VLAProcess:
 
         attention_mask = (input_ids.ne(self.tokenizer.pad_token_id),)
 
-        batch = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask[0],
-            labels=labels,
-            image_grid_spatiotemporal=image_grid_spatiotemporal,
-            pixel_values_videos=pixel_values_videos,
-            video_grid_spatiotemporal=video_grid_spatiotemporal,
-            pixel_values=pixel_values,
-        )
+        batch = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask[0],
+            "labels": labels,
+            "image_grid_spatiotemporal": image_grid_spatiotemporal,
+            "pixel_values_videos": pixel_values_videos,
+            "video_grid_spatiotemporal": video_grid_spatiotemporal,
+            "pixel_values": pixel_values,
+        }
 
         return batch
 
     def construct_chat_data(self, len_image, raw_lang):
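Because the padding mask above is built as a one-element tuple, the batch stores attention_mask[0]; the mask itself is just input_ids.ne(pad_token_id). A sketch with a made-up pad id:

    import torch

    pad_token_id = 0  # illustrative id only; the real value comes from self.tokenizer
    input_ids = torch.tensor([[12, 55, pad_token_id, pad_token_id]])
    attention_mask = (input_ids.ne(pad_token_id),)  # one-element tuple, as in the diff
    print(attention_mask[0])  # tensor([[ True,  True, False, False]])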
@@ -138,7 +137,7 @@ class Qwen2VLAProcess:
             },
         ]
 
-        for i in range(len_image):
+        for _i in range(len_image):
             messages[0]["content"].append(
                 {
                     "type": "image",