[Model] Systematic support for fp32 head, pooling models part (#23810)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi 2025-09-09 22:29:50 +08:00 committed by GitHub
parent a55cf41a09
commit 19332c0479
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 166 additions and 61 deletions

View File

@ -9,6 +9,7 @@ import mteb
import numpy as np import numpy as np
import pytest import pytest
import requests import requests
import torch
from tests.models.utils import (EmbedModelInfo, RerankModelInfo, from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
check_embeddings_close) check_embeddings_close)
@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs=None, vllm_extra_kwargs=None,
hf_model_callback=None, hf_model_callback=None,
atol=MTEB_EMBED_TOL): atol=MTEB_EMBED_TOL):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if not model_info.enable_test: if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
example_prompts = ["The chef prepared a delicious meal."] # Test embed_dims, isnan and whether to use normalize
example_prompts = ["The chef prepared a delicious meal." * 1000]
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
# Allow vllm to test using hf_overrides
if model_info.hf_overrides is not None: if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
model_config = vllm_model.llm.llm_engine.model_config model_config = vllm_model.llm.llm_engine.model_config
# Confirm whether vllm is using the correct architecture
if model_info.architecture: if model_info.architecture:
assert model_info.architecture in model_config.architectures assert model_info.architecture in model_config.architectures
# Confirm whether vllm uses the correct default_pooling_type, which
# relates to whether chunked prefill and prefix caching are enabled
assert (model_config._model_info.default_pooling_type == assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type) model_info.default_pooling_type)
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS) MTEB_EMBED_TASKS)
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
vllm_outputs = vllm_model.embed(example_prompts)
# Test embed_dims, isnan and whether to use normalize
vllm_outputs = vllm_model.embed(example_prompts,
truncate_prompt_tokens=-1)
assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
if model_info.mteb_score is None: if model_info.mteb_score is None:
with hf_runner(model_info.name, with hf_runner(model_info.name,
is_sentence_transformer=True, is_sentence_transformer=True,
dtype="float32") as hf_model: dtype="float32") as hf_model:
# e.g. setting default parameters for the encode method of hf_runner
if hf_model_callback is not None: if hf_model_callback is not None:
hf_model_callback(hf_model) hf_model_callback(hf_model)
@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
hf_model_callback=None, hf_model_callback=None,
vllm_mteb_encoder=VllmMtebEncoder, vllm_mteb_encoder=VllmMtebEncoder,
atol=MTEB_RERANK_TOL): atol=MTEB_RERANK_TOL):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if not model_info.enable_test: if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
# Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
# Allow vllm to test using hf_overrides
if model_info.hf_overrides is not None: if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
model_config = vllm_model.llm.llm_engine.model_config model_config = vllm_model.llm.llm_engine.model_config
# Confirm whether vllm is using the correct architecture
if model_info.architecture: if model_info.architecture:
assert (model_info.architecture in model_config.architectures) assert (model_info.architecture in model_config.architectures)
# Score API is only enabled for num_labels == 1
assert model_config.hf_config.num_labels == 1 assert model_config.hf_config.num_labels == 1
# Confirm whether vllm uses the correct default_pooling_type, which
# relates to whether chunked prefill and prefix caching are enabled
assert (model_config._model_info.default_pooling_type == assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type) model_info.default_pooling_type)
@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
languages=MTEB_RERANK_LANGS) languages=MTEB_RERANK_LANGS)
vllm_dtype = model_config.dtype vllm_dtype = model_config.dtype
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
if model_info.mteb_score is None: if model_info.mteb_score is None:
st_main_score, st_dtype = mteb_test_rerank_models_hf( st_main_score, st_dtype = mteb_test_rerank_models_hf(
hf_runner, model_info.name, hf_model_callback) hf_runner, model_info.name, hf_model_callback)

View File

@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS = [ RERANK_MODELS = [
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification", architecture="GemmaForSequenceClassification",
mteb_score=0.33757,
hf_overrides={ hf_overrides={
"architectures": "architectures":
["GemmaForSequenceClassification"], ["GemmaForSequenceClassification"],

View File

@ -745,7 +745,7 @@ class ModelConfig:
self.pooler_config = self._init_pooler_config() self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype( self.dtype: torch.dtype = _get_and_verify_dtype(
self.model, self.model,
self.hf_config, self.hf_config,
self.dtype, self.dtype,
@ -1751,6 +1751,32 @@ class ModelConfig:
# `llm as reranker` models defaults to not using pad_token. # `llm as reranker` models defaults to not using pad_token.
return getattr(self.hf_config, "use_pad_token", True) return getattr(self.hf_config, "use_pad_token", True)
@property
def head_dtype(self) -> torch.dtype:
"""
"head" refers to the last Linear layer(s) of an LLM,
such as the lm_head in a generation model,
or the score or classifier in a classification model.
The default head_dtype based on runner_type.\n
- The pooling model defaults to using fp32 head,
you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
- The generate model defaults to not using fp32 head,
you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
"""
head_dtype = _get_head_dtype(config=self.hf_config,
dtype=self.dtype,
runner_type=self.runner_type)
if head_dtype not in current_platform.supported_dtypes:
logger.warning_once(
"The current platform does not support [%s] head dtype, "
"fallback to model dtype [%s].", head_dtype, self.dtype)
return self.dtype
logger.debug_once("head dtype: %s", head_dtype)
return head_dtype
def get_and_verify_max_len(self, max_model_len: int): def get_and_verify_max_len(self, max_model_len: int):
# Consider max_model_len in tokenizer_config only when # Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding. # pooling models use absolute position_embedding.
@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
return torch_dtype return torch_dtype
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
runner_type: str) -> torch.dtype:
head_dtype: Optional[Union[str,
torch.dtype]] = getattr(config, "head_dtype",
None)
if head_dtype == "model":
return dtype
elif isinstance(head_dtype, str):
head_dtype = head_dtype.lower()
if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {head_dtype!r}")
return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
elif isinstance(head_dtype, torch.dtype):
return head_dtype
elif head_dtype is None:
if torch.float32 not in current_platform.supported_dtypes:
return dtype
if runner_type == "pooling":
return torch.float32
return dtype
else:
raise ValueError(f"Unknown dtype: {head_dtype}")
def _get_and_verify_max_len( def _get_and_verify_max_len(
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
tokenizer_config: Optional[dict], tokenizer_config: Optional[dict],

View File

@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from itertools import groupby from itertools import groupby
from typing import Callable, Optional, TypeVar, Union, cast from typing import Callable, Optional, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
class PoolerNormalize(PoolerActivation): class PoolerNormalize(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
x = F.normalize(pooled_data.float(), p=2, dim=-1) return F.normalize(pooled_data, p=2, dim=-1)
return x.to(pooled_data.dtype)
class PoolerMultiLabelClassify(PoolerActivation): class PoolerMultiLabelClassify(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return F.sigmoid(pooled_data)
class PoolerClassify(PoolerActivation): class PoolerClassify(PoolerActivation):
@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
pooled_data.shape[-1]) pooled_data.shape[-1])
if num_labels < 2: if num_labels < 2:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) return F.sigmoid(pooled_data)
return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation): class LambdaPoolerActivation(PoolerActivation):
@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
from vllm.model_executor.models.adapters import _load_st_projector from vllm.model_executor.models.adapters import _load_st_projector
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.projector = _load_st_projector( self.projector: Optional[nn.Module] = _load_st_projector(
vllm_config.model_config) if vllm_config else None vllm_config.model_config) if vllm_config else None
self.head_dtype = vllm_config.model_config.head_dtype
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata): pooling_metadata: PoolingMetadata):
@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data = torch.stack(pooled_data) pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension] # pooled_data shape: [batchsize, hidden_dimension]
pooled_data = pooled_data.to(self.head_dtype)
# Apply ST projector # Apply ST projector
if self.projector is not None: if self.projector is not None:
projector = cast(nn.Module, self.projector) pooled_data = self.projector(pooled_data)
def _proj(x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
y = projector(x.to(torch.float32))
return y.to(orig_dtype)
pooled_data = _proj(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension] # pooled_data shape: [batchsize, embedding_dimension]
pooling_params = get_pooling_params(pooling_metadata) pooling_params = get_pooling_params(pooling_metadata)
@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(activation=PoolerClassify(static_num_labels=False)) super().__init__(activation=PoolerClassify(static_num_labels=False))
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.head_dtype = vllm_config.model_config.head_dtype
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata): pooling_metadata: PoolingMetadata):
if isinstance(pooled_data, list):
pooled_data = [p.to(self.head_dtype) for p in pooled_data]
else:
pooled_data = pooled_data.to(self.head_dtype)
pooling_params = get_pooling_params(pooling_metadata) pooling_params = get_pooling_params(pooling_metadata)
# for softmax # for softmax
@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
self.act_fn = act_fn or PoolerClassify() self.act_fn = act_fn or PoolerClassify()
self.logit_bias: Optional[ self.logit_bias: Optional[
float] = vllm_config.model_config.pooler_config.logit_bias float] = vllm_config.model_config.pooler_config.logit_bias
self.head_dtype = vllm_config.model_config.head_dtype
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"classify", "score"} return {"classify", "score"}
@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
pooled_data = torch.stack(pooled_data) pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size] # pooled_data shape: [batchsize, hidden_size]
pooled_data = pooled_data.to(self.head_dtype)
if self.classifier is not None: if self.classifier is not None:
pooled_data = self.classifier(pooled_data) pooled_data = self.classifier(pooled_data)
# pooled_data shape: [batchsize, num_labels] # pooled_data shape: [batchsize, num_labels]

View File

@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
linear = nn.Linear(layer_config.get("in_features", 768), linear = nn.Linear(layer_config.get("in_features", 768),
layer_config.get("out_features", 768), layer_config.get("out_features", 768),
bias=layer_config.get("bias", True), bias=layer_config.get("bias", True),
dtype=torch.float32) dtype=model_config.head_dtype)
if not _load_dense_weights(linear, folder, model_config): if not _load_dense_weights(linear, folder, model_config):
continue continue
@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
layers.append(linear) layers.append(linear)
if act_name := layer_config.get("activation_function"): if act_name := layer_config.get("activation_function"):
layers.append(get_act_fn(act_name)) layers.append(get_act_fn(act_name))
return nn.Sequential(*layers).to(dtype=torch.float32) return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
except Exception: except Exception:
logger.exception("ST projector loading failed") logger.exception("ST projector loading failed")
@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
if weight_key in state_dict: if weight_key in state_dict:
weight_loader = getattr(linear.weight, "weight_loader", weight_loader = getattr(linear.weight, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(linear.weight, weight_loader(linear.weight, state_dict[weight_key])
state_dict[weight_key].to(torch.float32))
bias_key = weight_key.replace("weight", "bias") bias_key = weight_key.replace("weight", "bias")
if linear.bias is not None and bias_key in state_dict: if linear.bias is not None and bias_key in state_dict:
bias_loader = getattr(linear.bias, "weight_loader", bias_loader = getattr(linear.bias, "weight_loader",
default_weight_loader) default_weight_loader)
bias_loader(linear.bias, bias_loader(linear.bias, state_dict[bias_key])
state_dict[bias_key].to(torch.float32))
return True return True
except Exception: except Exception:
logger.exception("Failed to load %s", filename) logger.exception("Failed to load %s", filename)

View File

@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.bert = BertPoolingModel(vllm_config=vllm_config, self.bert = BertPoolingModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"), prefix=maybe_prefix(prefix, "bert"),
embedding_class=BertEmbedding) embedding_class=BertEmbedding)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size,
config.num_labels,
dtype=vllm_config.model_config.head_dtype)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None

View File

@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.new = GteNewModel(vllm_config=vllm_config, self.new = GteNewModel(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
add_pooling_layer=True) add_pooling_layer=True)
self.classifier = RowParallelLinear(config.hidden_size, self.classifier = ReplicatedLinear(
config.num_labels, config.hidden_size,
input_is_parallel=False, config.num_labels,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix( params_dtype=vllm_config.model_config.head_dtype,
prefix, "classifier"), prefix=maybe_prefix(prefix, "classifier"),
return_bias=False) return_bias=False)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None

View File

@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.transformer = GPT2Model(vllm_config=vllm_config, self.transformer = GPT2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt2")) prefix=maybe_prefix(prefix, "gpt2"))
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) self.score = nn.Linear(config.n_embd,
config.num_labels,
bias=False,
dtype=vllm_config.model_config.head_dtype)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
"encode": "encode":
Pooler.for_encode(pooler_config), Pooler.for_encode(pooler_config),
"classify": "classify":
Pooler.for_classify(pooler_config, classifier=None), Pooler.for_classify(pooler_config, classifier=self.score),
}) })
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
position_ids=positions, position_ids=positions,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors) intermediate_tensors=intermediate_tensors)
logits = self.score(hidden_states) return hidden_states
return logits
def _add_transformer_prefix( def _add_transformer_prefix(

View File

@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
delattr(self, attr) delattr(self, attr)
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.v_head = RowParallelLinear( self.head_dtype = vllm_config.model_config.head_dtype
config.hidden_size,
1, self.v_head = RowParallelLinear(config.hidden_size,
bias=False, 1,
input_is_parallel=False, bias=False,
prefix=maybe_prefix(prefix, "v_head"), input_is_parallel=False,
) params_dtype=self.head_dtype,
prefix=maybe_prefix(prefix, "v_head"),
return_bias=False)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors, hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds) inputs_embeds)
logits, _ = self.v_head(hidden_states) hidden_states = hidden_states.to(self.head_dtype)
logits = self.v_head(hidden_states)
return logits return logits

View File

@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
config.hidden_size, config.hidden_size,
num_labels, num_labels,
bias=score_bias, bias=score_bias,
dtype=torch.float32, dtype=vllm_config.model_config.head_dtype,
) )
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config

View File

@ -5,9 +5,9 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -28,13 +28,17 @@ logger = init_logger(__name__)
class JinaVLScorer(nn.Module): class JinaVLScorer(nn.Module):
def __init__(self, config: PretrainedConfig): def __init__(self, model_config: "ModelConfig"):
super().__init__() super().__init__()
config = model_config.hf_config
head_dtype = model_config.head_dtype
self.dense = ColumnParallelLinear(config.hidden_size, self.dense = ColumnParallelLinear(config.hidden_size,
config.hidden_size, config.hidden_size,
params_dtype=head_dtype,
bias=True) bias=True)
self.out_proj = RowParallelLinear(config.hidden_size, self.out_proj = RowParallelLinear(config.hidden_size,
config.num_labels, config.num_labels,
params_dtype=head_dtype,
bias=True) bias=True)
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, super().__init__(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "qwen2_vl")) prefix=maybe_prefix(prefix, "qwen2_vl"))
config = vllm_config.model_config.hf_config
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.score = JinaVLScorer(config) self.score = JinaVLScorer(vllm_config.model_config)
self.pooler = DispatchPooler({ self.pooler = DispatchPooler({
"encode": "encode":
Pooler.for_encode(pooler_config), Pooler.for_encode(pooler_config),

View File

@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.config = config self.config = config
self.model = ModernBertModel(vllm_config=vllm_config, self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert")) prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size,
config.num_labels,
dtype=vllm_config.model_config.head_dtype)
self.pooling = ModernBertPooler(config) self.pooling = ModernBertPooler(config)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config

View File

@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2Model(vllm_config=vllm_config, self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.head_dtype = vllm_config.model_config.head_dtype
self.score = nn.Sequential( self.score = nn.Sequential(
ColumnParallelLinear(config.hidden_size, ColumnParallelLinear(config.hidden_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
params_dtype=self.head_dtype,
return_bias=False), return_bias=False),
nn.ReLU(), nn.ReLU(),
RowParallelLinear(config.hidden_size, RowParallelLinear(config.hidden_size,
config.num_labels, config.num_labels,
params_dtype=self.head_dtype,
quant_config=quant_config, quant_config=quant_config,
return_bias=False), return_bias=False),
) )
@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors, hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds) inputs_embeds)
hidden_states = hidden_states.to(self.head_dtype)
logits = self.score(hidden_states) logits = self.score(hidden_states)
return logits return logits

View File

@ -8,7 +8,7 @@ import torch
from torch import nn from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.config import VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler) DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
class RobertaClassificationHead(nn.Module): class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks.""" """Head for sentence-level classification tasks."""
def __init__(self, config: RobertaConfig): def __init__(self, model_config: "ModelConfig"):
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) config = model_config.hf_config
self.out_proj = nn.Linear(config.hidden_size, config.num_labels) head_dtype = model_config.head_dtype
self.dense = nn.Linear(config.hidden_size,
config.hidden_size,
dtype=head_dtype)
self.out_proj = nn.Linear(config.hidden_size,
config.num_labels,
dtype=head_dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# CLSPool has already been applied in `pooling` # CLSPool has already been applied in `pooling`
@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.roberta = BertModel(vllm_config=vllm_config, self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"), prefix=maybe_prefix(prefix, "bert"),
embedding_class=RobertaEmbedding) embedding_class=RobertaEmbedding)
self.classifier = RobertaClassificationHead(config) self.classifier = RobertaClassificationHead(vllm_config.model_config)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None