# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import GemmaConfig, PaliGemmaConfig


def get_paligemma_config(precision: str):
    """Build a PaliGemmaConfig for the 224px variant with weights in the given precision."""
    config = {
        "image_token_index": None,
        "pad_token_id": 0,
        "bos_token_id": 2,
        "eos_token_id": 1,
    }

    # Per-variant image sizes: {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
    image_size = 224  # image_sizes[variant]; hardcoded to the 224px variant here
    patch_size = 14
    # One token per patch: (224 // 14) ** 2 = 256 image tokens.
    num_image_tokens = (image_size**2) // (patch_size**2)

    # Token id used for the <image> placeholder.
    config["image_token_index"] = 257152

    # Gemma decoder (2B-sized: 18 layers, hidden size 2048).
    text_config = {
        "vocab_size": 257152,
        "num_hidden_layers": 18,
        "num_key_value_heads": 1,
        "head_dim": 256,
        "torch_dtype": precision,
        "hidden_size": 2048,
        "hidden_activation": "gelu_pytorch_tanh",
        "num_attention_heads": 8,
        "intermediate_size": 16384,
        "is_encoder_decoder": False,
    }

    # SigLIP vision encoder; the classification head is unused.
    vision_config = {
        "torch_dtype": precision,
        "image_size": image_size,
        "patch_size": patch_size,
        "num_image_tokens": num_image_tokens,
        "hidden_size": 1152,
        "intermediate_size": 4304,
        "num_hidden_layers": 27,
        "num_attention_heads": 16,
        "projector_hidden_act": "gelu_fast",
        "vision_use_head": False,
    }

    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
    return final_config


def get_gemma_config(precision: str):
    """Build a text-only GemmaConfig with weights in the given precision."""
    # GemmaConfig already defaults to pad_token_id=0, bos_token_id=2 and
    # eos_token_id=1, and a text-only model has no image_token_index, so only
    # the architecture fields below need to be overridden.
    text_config = {
        "vocab_size": 257152,
        "num_hidden_layers": 18,
        "num_key_value_heads": 1,
        "head_dim": 256,
        "torch_dtype": precision,
        "hidden_size": 1024,
        "hidden_activation": "gelu_pytorch_tanh",
        "num_attention_heads": 8,
        "intermediate_size": 4096,
        "is_encoder_decoder": False,
    }

    final_config = GemmaConfig()
    final_config.update(text_config)
    return final_config
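

# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch, not part of the original script. The
# "float32" precision string and the printed fields are chosen here purely
# for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    paligemma_config = get_paligemma_config(precision="float32")
    # 224px images with 14px patches yield (224 // 14) ** 2 = 256 image tokens.
    print("image_token_index:", paligemma_config.image_token_index)
    print("text hidden size:", paligemma_config.text_config.hidden_size)
    print("vision hidden size:", paligemma_config.vision_config.hidden_size)

    gemma_config = get_gemma_config(precision="float32")
    print("gemma hidden size:", gemma_config.hidden_size)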