diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index 68b1cc80303ad..7336c30bdda33 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -9,6 +9,7 @@
 import mteb
 import numpy as np
 import pytest
 import requests
+import torch
 from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
                                 check_embeddings_close)
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
                            vllm_extra_kwargs=None,
                            hf_model_callback=None,
                            atol=MTEB_EMBED_TOL):
+    # A model family has many models with the same architecture,
+    # and we don't need to test each one.
     if not model_info.enable_test:
-        # A model family has many models with the same architecture,
-        # and we don't need to test each one.
        pytest.skip("Skipping test.")
 
-    example_prompts = ["The chef prepared a delicious meal."]
+    # Test embedding dims, NaN outputs and whether normalization is applied
+    example_prompts = ["The chef prepared a delicious meal." * 1000]
 
+    # Run vllm with the given dtype, such as float32
     vllm_extra_kwargs = vllm_extra_kwargs or {}
     vllm_extra_kwargs["dtype"] = model_info.dtype
 
+    # Pass hf_overrides to vllm when specified
     if model_info.hf_overrides is not None:
         vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
 
         model_config = vllm_model.llm.llm_engine.model_config
 
+        # Confirm that vllm is using the correct architecture
         if model_info.architecture:
             assert model_info.architecture in model_config.architectures
+
+        # Confirm that vllm uses the correct default_pooling_type, which
+        # determines whether chunked prefill and prefix caching are enabled
         assert (model_config._model_info.default_pooling_type ==
                 model_info.default_pooling_type)
 
         vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                               MTEB_EMBED_TASKS)
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
-        vllm_outputs = vllm_model.embed(example_prompts)
+
+        # Test embedding dims, NaN outputs and whether normalization is applied
+        vllm_outputs = vllm_model.embed(example_prompts,
+                                        truncate_prompt_tokens=-1)
+        assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
+
+    # Speed up the mteb test by pinning the
+    # SentenceTransformers mteb score to a constant
     if model_info.mteb_score is None:
         with hf_runner(model_info.name,
                        is_sentence_transformer=True,
                        dtype="float32") as hf_model:
+            # e.g. set default parameters for hf_runner's encode method
             if hf_model_callback is not None:
                 hf_model_callback(hf_model)
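(Aside, not part of the patch: the new assertions above guard NaN-free embeddings and, for poolers that normalize, unit-length outputs. A minimal standalone sketch of the same checks, with a hypothetical `emb` tensor standing in for `vllm_outputs`:)

```python
import torch

emb = torch.tensor([[0.6, 0.8], [1.0, 0.0]])  # hypothetical embeddings
assert not torch.any(torch.isnan(emb))
# For a pooler that applies normalization, each row should have unit L2 norm.
assert torch.allclose(emb.norm(p=2, dim=-1), torch.ones(emb.shape[0]))
```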
pytest.skip("Skipping test.") + # Allow vllm to test using the given dtype, such as float32 vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype + # Allow vllm to test using hf_overrides if model_info.hf_overrides is not None: vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides @@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner, model_config = vllm_model.llm.llm_engine.model_config + # Confirm whether vllm is using the correct architecture if model_info.architecture: assert (model_info.architecture in model_config.architectures) + + # Score API is only enabled for num_labels == 1 assert model_config.hf_config.num_labels == 1 + + # Confirm whether vllm uses the correct default_pooling_type, which + # relates to whether chunked prefill and prefix caching are enabled assert (model_config._model_info.default_pooling_type == model_info.default_pooling_type) @@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner, languages=MTEB_RERANK_LANGS) vllm_dtype = model_config.dtype + # Accelerate mteb test by setting + # SentenceTransformers mteb score to a constant if model_info.mteb_score is None: st_main_score, st_dtype = mteb_test_rerank_models_hf( hf_runner, model_info.name, hf_model_callback) diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py index eaa8bfb84ffdd..fc888157b402a 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models RERANK_MODELS = [ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", architecture="GemmaForSequenceClassification", + mteb_score=0.33757, hf_overrides={ "architectures": ["GemmaForSequenceClassification"], diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index e474420e3f04c..4f4673ac6e67a 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -745,7 +745,7 @@ class ModelConfig: self.pooler_config = self._init_pooler_config() - self.dtype = _get_and_verify_dtype( + self.dtype: torch.dtype = _get_and_verify_dtype( self.model, self.hf_config, self.dtype, @@ -1751,6 +1751,32 @@ class ModelConfig: # `llm as reranker` models defaults to not using pad_token. return getattr(self.hf_config, "use_pad_token", True) + @property + def head_dtype(self) -> torch.dtype: + """ + "head" refers to the last Linear layer(s) of an LLM, + such as the lm_head in a generation model, + or the score or classifier in a classification model. + + The default head_dtype based on runner_type.\n + - The pooling model defaults to using fp32 head, + you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n + - The generate model defaults to not using fp32 head, + you can use --hf-overrides '{"head_dtype": "float32"}' to enable it. + """ + head_dtype = _get_head_dtype(config=self.hf_config, + dtype=self.dtype, + runner_type=self.runner_type) + + if head_dtype not in current_platform.supported_dtypes: + logger.warning_once( + "The current platform does not support [%s] head dtype, " + "fallback to model dtype [%s].", head_dtype, self.dtype) + return self.dtype + + logger.debug_once("head dtype: %s", head_dtype) + return head_dtype + def get_and_verify_max_len(self, max_model_len: int): # Consider max_model_len in tokenizer_config only when # pooling models use absolute position_embedding. 
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
     return torch_dtype
 
 
+def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
+                    runner_type: str) -> torch.dtype:
+    head_dtype: Optional[Union[str,
+                               torch.dtype]] = getattr(config, "head_dtype",
+                                                       None)
+
+    if head_dtype == "model":
+        return dtype
+    elif isinstance(head_dtype, str):
+        head_dtype = head_dtype.lower()
+        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {head_dtype!r}")
+        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
+    elif isinstance(head_dtype, torch.dtype):
+        return head_dtype
+    elif head_dtype is None:
+        if torch.float32 not in current_platform.supported_dtypes:
+            return dtype
+        if runner_type == "pooling":
+            return torch.float32
+        return dtype
+    else:
+        raise ValueError(f"Unknown dtype: {head_dtype}")
+
+
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     tokenizer_config: Optional[dict],
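(Aside, not part of the patch: a quick sanity check of the resolution order `_get_head_dtype` implements, assuming float32 is in `current_platform.supported_dtypes`; `SimpleNamespace` stands in for a real `PretrainedConfig`:)

```python
from types import SimpleNamespace

import torch

# head_dtype unset: pooling runners upcast to float32, generate runners don't.
assert _get_head_dtype(SimpleNamespace(head_dtype=None),
                       torch.bfloat16, "pooling") == torch.float32
assert _get_head_dtype(SimpleNamespace(head_dtype=None),
                       torch.bfloat16, "generate") == torch.bfloat16
# "model" pins the head to the model dtype; a dtype string parses as usual.
assert _get_head_dtype(SimpleNamespace(head_dtype="model"),
                       torch.bfloat16, "pooling") == torch.bfloat16
assert _get_head_dtype(SimpleNamespace(head_dtype="float16"),
                       torch.bfloat16, "pooling") == torch.float16
```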
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index afe7ea7b83924..b571a8f866990 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
 from dataclasses import dataclass
 from enum import IntEnum
 from itertools import groupby
-from typing import Callable, Optional, TypeVar, Union, cast
+from typing import Callable, Optional, TypeVar, Union
 
 import torch
 import torch.nn as nn
@@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
 class PoolerNormalize(PoolerActivation):
 
     def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        x = F.normalize(pooled_data.float(), p=2, dim=-1)
-        return x.to(pooled_data.dtype)
+        return F.normalize(pooled_data, p=2, dim=-1)
 
 
 class PoolerMultiLabelClassify(PoolerActivation):
 
     def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
+        return F.sigmoid(pooled_data)
 
 
 class PoolerClassify(PoolerActivation):
@@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
                           pooled_data.shape[-1])
 
         if num_labels < 2:
-            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
+            return F.sigmoid(pooled_data)
 
-        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
+        return F.softmax(pooled_data, dim=-1)
 
 
 class LambdaPoolerActivation(PoolerActivation):
@@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
         from vllm.model_executor.models.adapters import _load_st_projector
 
         vllm_config = get_current_vllm_config()
-        self.projector = _load_st_projector(
+        self.projector: Optional[nn.Module] = _load_st_projector(
             vllm_config.model_config) if vllm_config else None
+        self.head_dtype = vllm_config.model_config.head_dtype
 
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
@@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_dimension]
 
+        pooled_data = pooled_data.to(self.head_dtype)
+
         # Apply ST projector
         if self.projector is not None:
-            projector = cast(nn.Module, self.projector)
-
-            def _proj(x: torch.Tensor) -> torch.Tensor:
-                orig_dtype = x.dtype
-                y = projector(x.to(torch.float32))
-                return y.to(orig_dtype)
-
-            pooled_data = _proj(pooled_data)
+            pooled_data = self.projector(pooled_data)
         # pooled_data shape: [batchsize, embedding_dimension]
 
         pooling_params = get_pooling_params(pooling_metadata)
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
 
     def __init__(self) -> None:
         super().__init__(activation=PoolerClassify(static_num_labels=False))
+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        self.head_dtype = vllm_config.model_config.head_dtype
 
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
+
+        if isinstance(pooled_data, list):
+            pooled_data = [p.to(self.head_dtype) for p in pooled_data]
+        else:
+            pooled_data = pooled_data.to(self.head_dtype)
+
         pooling_params = get_pooling_params(pooling_metadata)
 
         # for softmax
@@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
         self.act_fn = act_fn or PoolerClassify()
         self.logit_bias: Optional[
             float] = vllm_config.model_config.pooler_config.logit_bias
+        self.head_dtype = vllm_config.model_config.head_dtype
 
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"classify", "score"}
@@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_size]
 
+        pooled_data = pooled_data.to(self.head_dtype)
+
         if self.classifier is not None:
             pooled_data = self.classifier(pooled_data)
         # pooled_data shape: [batchsize, num_labels]
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index bb96bc559200c..c189208fa075b 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
                 linear = nn.Linear(layer_config.get("in_features", 768),
                                    layer_config.get("out_features", 768),
                                    bias=layer_config.get("bias", True),
-                                   dtype=torch.float32)
+                                   dtype=model_config.head_dtype)
 
                 if not _load_dense_weights(linear, folder, model_config):
                     continue
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
                 layers.append(linear)
                 if act_name := layer_config.get("activation_function"):
                     layers.append(get_act_fn(act_name))
-            return nn.Sequential(*layers).to(dtype=torch.float32)
+            return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
         except Exception:
             logger.exception("ST projector loading failed")
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
         if weight_key in state_dict:
             weight_loader = getattr(linear.weight, "weight_loader",
                                     default_weight_loader)
-            weight_loader(linear.weight,
-                          state_dict[weight_key].to(torch.float32))
+            weight_loader(linear.weight, state_dict[weight_key])
 
             bias_key = weight_key.replace("weight", "bias")
             if linear.bias is not None and bias_key in state_dict:
                 bias_loader = getattr(linear.bias, "weight_loader",
                                       default_weight_loader)
-                bias_loader(linear.bias,
-                            state_dict[bias_key].to(torch.float32))
+                bias_loader(linear.bias, state_dict[bias_key])
             return True
     except Exception:
         logger.exception("Failed to load %s", filename)
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 8f23439655ed7..c07e5364814ac 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
         self.bert = BertPoolingModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "bert"),
                                      embedding_class=BertEmbedding)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.classifier = nn.Linear(config.hidden_size,
+                                    config.num_labels,
+                                    dtype=vllm_config.model_config.head_dtype)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
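(Aside, not part of the patch: the model changes below all follow the same two-part pattern — build the head layer in `model_config.head_dtype`, then cast the pooled hidden states to that dtype before the final matmul. A self-contained sketch with hypothetical names:)

```python
import torch
import torch.nn as nn


class TinyClassificationHead(nn.Module):

    def __init__(self, hidden_size: int, num_labels: int,
                 head_dtype: torch.dtype):
        super().__init__()
        # Head weights live in head_dtype (float32 by default for pooling).
        self.score = nn.Linear(hidden_size, num_labels, dtype=head_dtype)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # Hidden states arrive in the model dtype; cast before the head.
        return self.score(pooled.to(self.score.weight.dtype))


head = TinyClassificationHead(8, 2, torch.float32)
logits = head(torch.randn(4, 8, dtype=torch.bfloat16))
assert logits.dtype == torch.float32
```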
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 3be7e11d947d5..b758cbf28d893 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.new = GteNewModel(vllm_config=vllm_config,
                                prefix=prefix,
                                add_pooling_layer=True)
-        self.classifier = RowParallelLinear(config.hidden_size,
-                                            config.num_labels,
-                                            input_is_parallel=False,
-                                            bias=True,
-                                            quant_config=quant_config,
-                                            prefix=maybe_prefix(
-                                                prefix, "classifier"),
-                                            return_bias=False)
+        self.classifier = ReplicatedLinear(
+            config.hidden_size,
+            config.num_labels,
+            bias=True,
+            quant_config=quant_config,
+            params_dtype=vllm_config.model_config.head_dtype,
+            prefix=maybe_prefix(prefix, "classifier"),
+            return_bias=False)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 4446b5ab181c1..0f6521e44e6be 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
         config = vllm_config.model_config.hf_config
         self.transformer = GPT2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "gpt2"))
-        self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
+        self.score = nn.Linear(config.n_embd,
+                               config.num_labels,
+                               bias=False,
+                               dtype=vllm_config.model_config.head_dtype)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
                 "encode":
                 Pooler.for_encode(pooler_config),
                 "classify":
-                Pooler.for_classify(pooler_config, classifier=None),
+                Pooler.for_classify(pooler_config, classifier=self.score),
             })
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
             position_ids=positions,
             inputs_embeds=inputs_embeds,
             intermediate_tensors=intermediate_tensors)
-        logits = self.score(hidden_states)
-        return logits
+        return hidden_states
 
 
 def _add_transformer_prefix(
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 320e8d9d480c3..ce94328797ed6 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
                 delattr(self, attr)
 
         config = vllm_config.model_config.hf_config
-        self.v_head = RowParallelLinear(
-            config.hidden_size,
-            1,
-            bias=False,
-            input_is_parallel=False,
-            prefix=maybe_prefix(prefix, "v_head"),
-        )
+        self.head_dtype = vllm_config.model_config.head_dtype
+
+        self.v_head = RowParallelLinear(config.hidden_size,
+                                        1,
+                                        bias=False,
+                                        input_is_parallel=False,
+                                        params_dtype=self.head_dtype,
+                                        prefix=maybe_prefix(prefix, "v_head"),
+                                        return_bias=False)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
-        logits, _ = self.v_head(hidden_states)
+        hidden_states = hidden_states.to(self.head_dtype)
+        logits = self.v_head(hidden_states)
         return logits
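(Aside, not part of the patch: in the InternLM2 hunk above, `return_bias=False` is what lets `logits, _ = self.v_head(...)` become `logits = self.v_head(...)` — vLLM's parallel linear layers return an `(output, output_bias)` tuple by default and a plain tensor when the flag is off. A hypothetical stand-in that mimics the knob:)

```python
import torch
import torch.nn as nn


class StubLinear(nn.Linear):
    """Hypothetical stand-in for the return_bias knob on RowParallelLinear."""

    def __init__(self, *args, return_bias: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_bias = return_bias

    def forward(self, x: torch.Tensor):
        out = super().forward(x)
        return (out, self.bias) if self.return_bias else out


v_head = StubLinear(8, 1, bias=False, return_bias=False)
logits = v_head(torch.randn(2, 8))  # plain tensor, no tuple unpacking
assert isinstance(logits, torch.Tensor)
```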
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index aebd2cbe2e999..550fde17b6c53 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
             config.hidden_size,
             num_labels,
             bias=score_bias,
-            dtype=torch.float32,
+            dtype=vllm_config.model_config.head_dtype,
         )
 
         pooler_config = vllm_config.model_config.pooler_config
diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py
index 140b0d1674728..f8c2a1e507a74 100644
--- a/vllm/model_executor/models/jina_vl.py
+++ b/vllm/model_executor/models/jina_vl.py
@@ -5,9 +5,9 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, PretrainedConfig
+from transformers import BatchFeature
 
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -28,13 +28,17 @@ logger = init_logger(__name__)
 
 class JinaVLScorer(nn.Module):
 
-    def __init__(self, config: PretrainedConfig):
+    def __init__(self, model_config: "ModelConfig"):
         super().__init__()
+        config = model_config.hf_config
+        head_dtype = model_config.head_dtype
         self.dense = ColumnParallelLinear(config.hidden_size,
                                           config.hidden_size,
+                                          params_dtype=head_dtype,
                                           bias=True)
         self.out_proj = RowParallelLinear(config.hidden_size,
                                           config.num_labels,
+                                          params_dtype=head_dtype,
                                           bias=True)
 
     def forward(self, x, **kwargs):
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,
                          prefix=maybe_prefix(prefix, "qwen2_vl"))
-        config = vllm_config.model_config.hf_config
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
 
-        self.score = JinaVLScorer(config)
+        self.score = JinaVLScorer(vllm_config.model_config)
 
         self.pooler = DispatchPooler({
             "encode":
             Pooler.for_encode(pooler_config),
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 776287589808a..1d5da3139de92 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.config = config
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.classifier = nn.Linear(config.hidden_size,
+                                    config.num_labels,
+                                    dtype=vllm_config.model_config.head_dtype)
         self.pooling = ModernBertPooler(config)
 
         pooler_config = vllm_config.model_config.pooler_config
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 421b43563bade..2bd9d2b52628a 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
         self.quant_config = quant_config
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
+        self.head_dtype = vllm_config.model_config.head_dtype
 
         self.score = nn.Sequential(
             ColumnParallelLinear(config.hidden_size,
                                  config.hidden_size,
                                  quant_config=quant_config,
+                                 params_dtype=self.head_dtype,
                                  return_bias=False),
             nn.ReLU(),
             RowParallelLinear(config.hidden_size,
                               config.num_labels,
+                              params_dtype=self.head_dtype,
                               quant_config=quant_config,
                               return_bias=False),
         )
 
@@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
+        hidden_states = hidden_states.to(self.head_dtype)
         logits = self.score(hidden_states)
         return logits
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 2bfa51162910b..ba405be416876 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from transformers import RobertaConfig
 
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
 class RobertaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 
-    def __init__(self, config: RobertaConfig):
+    def __init__(self, model_config: "ModelConfig"):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+        config = model_config.hf_config
+        head_dtype = model_config.head_dtype
+        self.dense = nn.Linear(config.hidden_size,
+                               config.hidden_size,
+                               dtype=head_dtype)
+        self.out_proj = nn.Linear(config.hidden_size,
+                                  config.num_labels,
+                                  dtype=head_dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # CLSPool has already been applied in `pooling`
@@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
         self.roberta = BertModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "bert"),
                                  embedding_class=RobertaEmbedding)
-        self.classifier = RobertaClassificationHead(config)
+        self.classifier = RobertaClassificationHead(vllm_config.model_config)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None