Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:15:42 +08:00)
[New Model]: google/embeddinggemma-300m (#24318)

Signed-off-by: wang.yuqi <noooop@126.com>

parent 53b19ccdd5
commit 6d6c6b05d3
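
For reviewers, a minimal usage sketch of what this commit enables, through the `LLM.embed` API that the docs change below references. This is not part of the commit; the `task`/`dtype` parameter spellings and the output layout are assumptions from vLLM's docs of this period, and bfloat16 is chosen because the `_FLOAT16_NOT_SUPPORTED_MODELS` hunk below rejects float16 for gemma3_text:

from vllm import LLM

# Load the new embedding model in pooling mode; bfloat16 sidesteps the
# float16 numerical-instability guard added for gemma3_text in this commit.
llm = LLM(model="google/embeddinggemma-300m", task="embed", dtype="bfloat16")
outputs = llm.embed(["The chef prepared a delicious meal."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality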
@@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API
 |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
 | `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
 | `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
 | `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
 | `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
@@ -10,7 +10,8 @@ import numpy as np
 import pytest
 import requests
 
-from tests.models.utils import EmbedModelInfo, RerankModelInfo
+from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
+                                check_embeddings_close)
 
 # Most embedding models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -163,12 +164,14 @@ def mteb_test_embed_models(hf_runner,
                            model_info: EmbedModelInfo,
                            vllm_extra_kwargs=None,
                            hf_model_callback=None,
-                           atol=MTEB_RERANK_TOL):
+                           atol=MTEB_EMBED_TOL):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
         pytest.skip("Skipping test.")
 
+    example_prompts = ["The chef prepared a delicious meal."]
+
     vllm_extra_kwargs = vllm_extra_kwargs or {}
     vllm_extra_kwargs["dtype"] = model_info.dtype
 
@@ -191,6 +194,7 @@ def mteb_test_embed_models(hf_runner,
     vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                           MTEB_EMBED_TASKS)
     vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
+    vllm_outputs = vllm_model.embed(example_prompts)
 
     if model_info.mteb_score is None:
         with hf_runner(model_info.name,
@@ -202,6 +206,16 @@
 
         st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
         st_dtype = next(hf_model.model.parameters()).dtype
+
+        # Test embed_dims and whether to use normalize
+        hf_outputs = hf_model.encode(example_prompts)
+        check_embeddings_close(
+            embeddings_0_lst=hf_outputs,
+            embeddings_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+            tol=1e-2,
+        )
     else:
         st_main_score = model_info.mteb_score
         st_dtype = "Constant"
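
`check_embeddings_close` lives in `tests.models.utils` and is not shown in this diff; a rough standalone equivalent of the assertion it presumably performs (an assumption about the helper's semantics, not its actual code):

import torch

def embeddings_close(a, b, tol: float) -> bool:
    # Element-wise comparison of two embedding batches within a tolerance,
    # mirroring (by assumption) what check_embeddings_close asserts.
    return torch.allclose(torch.tensor(a), torch.tensor(b), atol=tol)

print(embeddings_close([[0.10, 0.20]], [[0.104, 0.201]], tol=1e-2))  # True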
@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 
-from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
+                      LASTPoolingEmbedModelInfo)
 from .mteb_utils import mteb_test_embed_models
 
 # ST models with projector (Dense) layers
@@ -13,6 +14,10 @@ ST_PROJECTOR_MODELS = [
                              mteb_score=0.688611955,
                              enable_test=True,
                              ),
+    LASTPoolingEmbedModelInfo("google/embeddinggemma-300m",
+                              architecture="Gemma3TextModel",
+                              mteb_score=0.7473819294684156,
+                              enable_test=True)
 ]
 
 
@@ -352,6 +352,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501
+    "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
@@ -2750,6 +2750,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
 _FLOAT16_NOT_SUPPORTED_MODELS = {
     "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
+    "gemma3_text":
+    "Numerical instability. Please use bfloat16 or float32 instead.",
     "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
 }
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
         if not dense_modules:
             return None
 
-        module = dense_modules[0]
-        folder = module.get("path", "")
+        layers = []
+        for module in dense_modules:
+            folder = module.get("path", "")
 
-        config_path = f"{folder}/config.json" if folder else "config.json"
-        layer_config = get_hf_file_to_dict(config_path, model_config.model,
-                                           model_config.revision)
-        if not layer_config:
-            return None
+            config_path = f"{folder}/config.json" if folder else "config.json"
+            layer_config = get_hf_file_to_dict(config_path, model_config.model,
+                                               model_config.revision)
+            if not layer_config:
+                continue
 
-        linear = nn.Linear(layer_config.get("in_features", 768),
-                           layer_config.get("out_features", 768),
-                           bias=layer_config.get("bias", True),
-                           dtype=torch.float32)
+            linear = nn.Linear(layer_config.get("in_features", 768),
+                               layer_config.get("out_features", 768),
+                               bias=layer_config.get("bias", True),
+                               dtype=torch.float32)
 
-        if _load_dense_weights(linear, folder, model_config):
-            layers = [linear]
+            if not _load_dense_weights(linear, folder, model_config):
+                continue
 
-            if act_name := layer_config.get("activation_function"):
-                layers.append(get_act_fn(act_name))
-            return nn.Sequential(*layers).to(dtype=torch.float32)
+            layers.append(linear)
+            if act_name := layer_config.get("activation_function"):
+                layers.append(get_act_fn(act_name))
+        return nn.Sequential(*layers).to(dtype=torch.float32)
 
     except Exception:
         logger.exception("ST projector loading failed")
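
Context for the loop above: some Sentence Transformers checkpoints, google/embeddinggemma-300m among them, ship more than one Dense module, so reading only `dense_modules[0]` silently dropped the rest of the projector. A standalone sketch of the stacked projector the loop can now assemble (the 768 -> 3072 -> 768 shapes are an assumption based on embeddinggemma's published module layout, not something read from this diff):

import torch
from torch import nn

# One nn.Linear per Dense module, appended in order, as the patched loop does.
projector = nn.Sequential(
    nn.Linear(768, 3072, bias=True, dtype=torch.float32),
    nn.Linear(3072, 768, bias=True, dtype=torch.float32),
)

pooled = torch.randn(2, 768)    # two pooled hidden states
print(projector(pooled).shape)  # torch.Size([2, 768])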
@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
         raise NotImplementedError
 
 
+class Gemma3TextModelConfig:
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        hf_config = vllm_config.model_config.hf_config
+        hf_config.is_causal = not hf_config.use_bidirectional_attention
+
+
 class GteNewModelConfig(VerifyAndUpdateConfig):
 
     @staticmethod
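
The new hook derives `is_causal` from the checkpoint's `use_bidirectional_attention` flag. A self-contained illustration with a dummy config object (the real hook receives vLLM's `VllmConfig` at model-load time; `SimpleNamespace` is only a stand-in here):

from types import SimpleNamespace

# embeddinggemma-style config: bidirectional attention requested.
hf_config = SimpleNamespace(use_bidirectional_attention=True)
hf_config.is_causal = not hf_config.use_bidirectional_attention
assert hf_config.is_causal is False  # selects ENCODER_ONLY attention below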
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "GteNewForSequenceClassification": GteNewModelConfig,
+    "Gemma3TextModel": Gemma3TextModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
     "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import Gemma3TextConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
             rope_scaling=self.rope_scaling,
         )
 
-        # Initialize the attention.
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              logits_soft_cap=attn_logits_soft_cap,
-                              per_layer_sliding_window=sliding_window,
-                              prefix=f"{prefix}.attn")
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(self.num_heads,
+                             self.head_dim,
+                             self.scaling,
+                             num_kv_heads=self.num_kv_heads,
+                             cache_config=cache_config,
+                             quant_config=quant_config,
+                             attn_type=attn_type,
+                             logits_soft_cap=attn_logits_soft_cap,
+                             per_layer_sliding_window=sliding_window,
+                             prefix=f"{prefix}.attn")
 
     def forward(
         self,
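
Why the attention type matters: under `ENCODER_ONLY` the causal mask is dropped, so every token attends to the whole sequence, which the bidirectional embeddinggemma checkpoint expects. A toy mask comparison in plain torch (illustrative only; vLLM's attention backends do not materialize dense masks like this):

import torch

seq_len = 4
causal_mask = torch.ones(seq_len, seq_len).tril()  # DECODER: no look-ahead
full_mask = torch.ones(seq_len, seq_len)           # ENCODER_ONLY: all-to-all
print(causal_mask)
print(full_mask)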
@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
     "BertModel": ("bert", "BertEmbeddingModel"),
     "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
+    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
     "GritLM": ("gritlm", "GritLM"),
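
Finally, a hedged sketch of querying the model through vLLM's OpenAI-compatible server once the registry entry above is in place. The server is assumed to be started separately (for example `vllm serve google/embeddinggemma-300m`), and the host, port, and response shape follow the standard /v1/embeddings contract:

import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "google/embeddinggemma-300m",
        "input": ["The chef prepared a delicious meal."],
    },
)
print(resp.json()["data"][0]["embedding"][:8])  # first few dimensions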