diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9db6f8036a73..bdb29aac333c 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
 |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
 | `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
 | `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
 | `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
 | `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index 7be1bba2ff69..68b1cc80303a 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -10,7 +10,8 @@ import numpy as np
 import pytest
 import requests
 
-from tests.models.utils import EmbedModelInfo, RerankModelInfo
+from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
+                                check_embeddings_close)
 
 # Most embedding models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -163,12 +164,14 @@ def mteb_test_embed_models(hf_runner,
                           model_info: EmbedModelInfo,
                           vllm_extra_kwargs=None,
                           hf_model_callback=None,
-                          atol=MTEB_RERANK_TOL):
+                          atol=MTEB_EMBED_TOL):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
         pytest.skip("Skipping test.")
 
+    example_prompts = ["The chef prepared a delicious meal."]
+
     vllm_extra_kwargs = vllm_extra_kwargs or {}
     vllm_extra_kwargs["dtype"] = model_info.dtype
 
@@ -191,6 +194,7 @@ def mteb_test_embed_models(hf_runner,
         vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                               MTEB_EMBED_TASKS)
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
+        vllm_outputs = vllm_model.embed(example_prompts)
 
     if model_info.mteb_score is None:
         with hf_runner(model_info.name,
@@ -202,6 +206,16 @@ def mteb_test_embed_models(hf_runner,
 
             st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
             st_dtype = next(hf_model.model.parameters()).dtype
+
+            # Test embed_dims and whether to use normalize
+            hf_outputs = hf_model.encode(example_prompts)
+            check_embeddings_close(
+                embeddings_0_lst=hf_outputs,
+                embeddings_1_lst=vllm_outputs,
+                name_0="hf",
+                name_1="vllm",
+                tol=1e-2,
+            )
     else:
         st_main_score = model_info.mteb_score
         st_dtype = "Constant"
diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py
index bafeb4060d80..9301e705c433 100644
--- a/tests/models/language/pooling/test_st_projector.py
+++ b/tests/models/language/pooling/test_st_projector.py
@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 
-from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
+                      LASTPoolingEmbedModelInfo)
 from .mteb_utils import mteb_test_embed_models
 
 # ST models with projector (Dense) layers
@@ -13,6 +14,10 @@ ST_PROJECTOR_MODELS = [
         mteb_score=0.688611955,
         enable_test=True,
     ),
+    LASTPoolingEmbedModelInfo("google/embeddinggemma-300m",
+                              architecture="Gemma3TextModel",
+                              mteb_score=0.7473819294684156,
+                              enable_test=True)
 ]
 
 
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 38efb01341eb..c6ff50b5426e 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -352,6 +352,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),  # noqa: E501
+    "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 8bdc22acf380..c4434c37f4c7 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -2750,6 +2750,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
 _FLOAT16_NOT_SUPPORTED_MODELS = {
     "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
+    "gemma3_text":
+    "Numerical instability. Please use bfloat16 or float32 instead.",
     "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
 }
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 50c2cd97f3d0..bb96bc559200 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
         if not dense_modules:
             return None
 
-        module = dense_modules[0]
-        folder = module.get("path", "")
+        layers = []
+        for module in dense_modules:
+            folder = module.get("path", "")
 
-        config_path = f"{folder}/config.json" if folder else "config.json"
-        layer_config = get_hf_file_to_dict(config_path, model_config.model,
-                                           model_config.revision)
-        if not layer_config:
-            return None
+            config_path = f"{folder}/config.json" if folder else "config.json"
+            layer_config = get_hf_file_to_dict(config_path, model_config.model,
+                                               model_config.revision)
+            if not layer_config:
+                continue
 
-        linear = nn.Linear(layer_config.get("in_features", 768),
-                           layer_config.get("out_features", 768),
-                           bias=layer_config.get("bias", True),
-                           dtype=torch.float32)
+            linear = nn.Linear(layer_config.get("in_features", 768),
+                               layer_config.get("out_features", 768),
+                               bias=layer_config.get("bias", True),
+                               dtype=torch.float32)
 
-        if _load_dense_weights(linear, folder, model_config):
-            layers = [linear]
+            if not _load_dense_weights(linear, folder, model_config):
+                continue
+
+            layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))
 
-            return nn.Sequential(*layers).to(dtype=torch.float32)
-
+        return nn.Sequential(*layers).to(dtype=torch.float32)
     except Exception:
         logger.exception("ST projector loading failed")
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 8b76a54332f8..f38e7fc20220 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
         raise NotImplementedError
 
 
+class Gemma3TextModelConfig:
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        hf_config = vllm_config.model_config.hf_config
+        hf_config.is_causal = not hf_config.use_bidirectional_attention
+
+
 class GteNewModelConfig(VerifyAndUpdateConfig):
 
     @staticmethod
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "GteNewForSequenceClassification": GteNewModelConfig,
+    "Gemma3TextModel": Gemma3TextModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
     "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 410c715d5241..1263e3049a14 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import Gemma3TextConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
             rope_scaling=self.rope_scaling,
         )
 
-        # Initialize the attention.
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              logits_soft_cap=attn_logits_soft_cap,
-                              per_layer_sliding_window=sliding_window,
-                              prefix=f"{prefix}.attn")
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(self.num_heads,
+                             self.head_dim,
+                             self.scaling,
+                             num_kv_heads=self.num_kv_heads,
+                             cache_config=cache_config,
+                             quant_config=quant_config,
+                             attn_type=attn_type,
+                             logits_soft_cap=attn_logits_soft_cap,
+                             per_layer_sliding_window=sliding_window,
+                             prefix=f"{prefix}.attn")
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 38d300b03d2c..c522fcab7f33 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
     "BertModel": ("bert", "BertEmbeddingModel"),
     "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
+    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
     "GritLM": ("gritlm", "GritLM"),