mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 14:27:08 +08:00
[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
a55cf41a09
commit
19332c0479
@ -9,6 +9,7 @@ import mteb
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
|
from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
|
||||||
check_embeddings_close)
|
check_embeddings_close)
|
||||||
@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
vllm_extra_kwargs=None,
|
vllm_extra_kwargs=None,
|
||||||
hf_model_callback=None,
|
hf_model_callback=None,
|
||||||
atol=MTEB_EMBED_TOL):
|
atol=MTEB_EMBED_TOL):
|
||||||
|
# A model family has many models with the same architecture,
|
||||||
|
# and we don't need to test each one.
|
||||||
if not model_info.enable_test:
|
if not model_info.enable_test:
|
||||||
# A model family has many models with the same architecture,
|
|
||||||
# and we don't need to test each one.
|
|
||||||
pytest.skip("Skipping test.")
|
pytest.skip("Skipping test.")
|
||||||
|
|
||||||
example_prompts = ["The chef prepared a delicious meal."]
|
# Test embed_dims, isnan and whether to use normalize
|
||||||
|
example_prompts = ["The chef prepared a delicious meal." * 1000]
|
||||||
|
|
||||||
|
# Allow vllm to test using the given dtype, such as float32
|
||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||||
|
|
||||||
|
# Allow vllm to test using hf_overrides
|
||||||
if model_info.hf_overrides is not None:
|
if model_info.hf_overrides is not None:
|
||||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
|
|
||||||
|
# Confirm whether vllm is using the correct architecture
|
||||||
if model_info.architecture:
|
if model_info.architecture:
|
||||||
assert model_info.architecture in model_config.architectures
|
assert model_info.architecture in model_config.architectures
|
||||||
|
|
||||||
|
# Confirm whether vllm uses the correct default_pooling_type, which
|
||||||
|
# relates to whether chunked prefill and prefix caching are enabled
|
||||||
assert (model_config._model_info.default_pooling_type ==
|
assert (model_config._model_info.default_pooling_type ==
|
||||||
model_info.default_pooling_type)
|
model_info.default_pooling_type)
|
||||||
|
|
||||||
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
||||||
MTEB_EMBED_TASKS)
|
MTEB_EMBED_TASKS)
|
||||||
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
|
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
|
||||||
vllm_outputs = vllm_model.embed(example_prompts)
|
|
||||||
|
|
||||||
|
# Test embed_dims, isnan and whether to use normalize
|
||||||
|
vllm_outputs = vllm_model.embed(example_prompts,
|
||||||
|
truncate_prompt_tokens=-1)
|
||||||
|
assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
|
||||||
|
|
||||||
|
# Accelerate mteb test by setting
|
||||||
|
# SentenceTransformers mteb score to a constant
|
||||||
if model_info.mteb_score is None:
|
if model_info.mteb_score is None:
|
||||||
with hf_runner(model_info.name,
|
with hf_runner(model_info.name,
|
||||||
is_sentence_transformer=True,
|
is_sentence_transformer=True,
|
||||||
dtype="float32") as hf_model:
|
dtype="float32") as hf_model:
|
||||||
|
|
||||||
|
# e.g. setting default parameters for the encode method of hf_runner
|
||||||
if hf_model_callback is not None:
|
if hf_model_callback is not None:
|
||||||
hf_model_callback(hf_model)
|
hf_model_callback(hf_model)
|
||||||
|
|
||||||
@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
hf_model_callback=None,
|
hf_model_callback=None,
|
||||||
vllm_mteb_encoder=VllmMtebEncoder,
|
vllm_mteb_encoder=VllmMtebEncoder,
|
||||||
atol=MTEB_RERANK_TOL):
|
atol=MTEB_RERANK_TOL):
|
||||||
|
# A model family has many models with the same architecture,
|
||||||
|
# and we don't need to test each one.
|
||||||
if not model_info.enable_test:
|
if not model_info.enable_test:
|
||||||
# A model family has many models with the same architecture,
|
|
||||||
# and we don't need to test each one.
|
|
||||||
pytest.skip("Skipping test.")
|
pytest.skip("Skipping test.")
|
||||||
|
|
||||||
|
# Allow vllm to test using the given dtype, such as float32
|
||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||||
|
|
||||||
|
# Allow vllm to test using hf_overrides
|
||||||
if model_info.hf_overrides is not None:
|
if model_info.hf_overrides is not None:
|
||||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
|
|
||||||
|
# Confirm whether vllm is using the correct architecture
|
||||||
if model_info.architecture:
|
if model_info.architecture:
|
||||||
assert (model_info.architecture in model_config.architectures)
|
assert (model_info.architecture in model_config.architectures)
|
||||||
|
|
||||||
|
# Score API is only enabled for num_labels == 1
|
||||||
assert model_config.hf_config.num_labels == 1
|
assert model_config.hf_config.num_labels == 1
|
||||||
|
|
||||||
|
# Confirm whether vllm uses the correct default_pooling_type, which
|
||||||
|
# relates to whether chunked prefill and prefix caching are enabled
|
||||||
assert (model_config._model_info.default_pooling_type ==
|
assert (model_config._model_info.default_pooling_type ==
|
||||||
model_info.default_pooling_type)
|
model_info.default_pooling_type)
|
||||||
|
|
||||||
@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
languages=MTEB_RERANK_LANGS)
|
languages=MTEB_RERANK_LANGS)
|
||||||
vllm_dtype = model_config.dtype
|
vllm_dtype = model_config.dtype
|
||||||
|
|
||||||
|
# Accelerate mteb test by setting
|
||||||
|
# SentenceTransformers mteb score to a constant
|
||||||
if model_info.mteb_score is None:
|
if model_info.mteb_score is None:
|
||||||
st_main_score, st_dtype = mteb_test_rerank_models_hf(
|
st_main_score, st_dtype = mteb_test_rerank_models_hf(
|
||||||
hf_runner, model_info.name, hf_model_callback)
|
hf_runner, model_info.name, hf_model_callback)
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
|
|||||||
RERANK_MODELS = [
|
RERANK_MODELS = [
|
||||||
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
|
||||||
architecture="GemmaForSequenceClassification",
|
architecture="GemmaForSequenceClassification",
|
||||||
|
mteb_score=0.33757,
|
||||||
hf_overrides={
|
hf_overrides={
|
||||||
"architectures":
|
"architectures":
|
||||||
["GemmaForSequenceClassification"],
|
["GemmaForSequenceClassification"],
|
||||||
|
|||||||
@ -745,7 +745,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
self.pooler_config = self._init_pooler_config()
|
self.pooler_config = self._init_pooler_config()
|
||||||
|
|
||||||
self.dtype = _get_and_verify_dtype(
|
self.dtype: torch.dtype = _get_and_verify_dtype(
|
||||||
self.model,
|
self.model,
|
||||||
self.hf_config,
|
self.hf_config,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
@ -1751,6 +1751,32 @@ class ModelConfig:
|
|||||||
# `llm as reranker` models defaults to not using pad_token.
|
# `llm as reranker` models defaults to not using pad_token.
|
||||||
return getattr(self.hf_config, "use_pad_token", True)
|
return getattr(self.hf_config, "use_pad_token", True)
|
||||||
|
|
||||||
|
@property
def head_dtype(self) -> torch.dtype:
    """Return the dtype used for the model's "head".

    "head" refers to the last Linear layer(s) of an LLM, such as the
    lm_head in a generation model, or the score or classifier in a
    classification model.

    The default is chosen from the runner type:

    - The pooling model defaults to using an fp32 head; you can use
      --hf-overrides '{"head_dtype": "model"}' to disable it.
    - The generate model defaults to not using an fp32 head; you can use
      --hf-overrides '{"head_dtype": "float32"}' to enable it.

    If the resolved dtype is not listed in the current platform's
    supported dtypes, a warning is logged and the model dtype is
    returned instead.
    """
    resolved = _get_head_dtype(config=self.hf_config,
                               dtype=self.dtype,
                               runner_type=self.runner_type)

    # Happy path first: the platform can run the requested head dtype.
    if resolved in current_platform.supported_dtypes:
        logger.debug_once("head dtype: %s", resolved)
        return resolved

    # Otherwise degrade to the dtype the rest of the model already uses.
    logger.warning_once(
        "The current platform does not support [%s] head dtype, "
        "fallback to model dtype [%s].", resolved, self.dtype)
    return self.dtype
|
||||||
|
|
||||||
def get_and_verify_max_len(self, max_model_len: int):
|
def get_and_verify_max_len(self, max_model_len: int):
|
||||||
# Consider max_model_len in tokenizer_config only when
|
# Consider max_model_len in tokenizer_config only when
|
||||||
# pooling models use absolute position_embedding.
|
# pooling models use absolute position_embedding.
|
||||||
@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
|
|||||||
return torch_dtype
|
return torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
                    runner_type: str) -> torch.dtype:
    """Resolve the dtype of the model's "head" layer(s).

    The requested value is read from the optional ``head_dtype``
    attribute of ``config``:

    - ``"model"`` (any letter case): reuse the model ``dtype``.
    - any other string: lowercased and looked up in
      ``_STR_DTYPE_TO_TORCH_DTYPE``.
    - a ``torch.dtype``: returned unchanged.
    - missing / ``None``: ``torch.float32`` when ``runner_type`` is
      ``"pooling"`` and the current platform lists float32 among its
      supported dtypes; otherwise the model ``dtype``.

    Args:
        config: HF config object that may carry a ``head_dtype`` attribute.
        dtype: The already-resolved model dtype.
        runner_type: Runner type; only ``"pooling"`` changes the default.

    Returns:
        The resolved head dtype.

    Raises:
        ValueError: If ``head_dtype`` is a string that is neither
            ``"model"`` nor a known dtype name, or is of an unsupported
            type.
    """
    head_dtype: Optional[Union[str,
                               torch.dtype]] = getattr(config, "head_dtype",
                                                       None)

    if head_dtype is None:
        # No explicit request: only pooling models default to an fp32 head,
        # and only when the platform can actually run float32.
        if torch.float32 not in current_platform.supported_dtypes:
            return dtype
        if runner_type == "pooling":
            return torch.float32
        return dtype

    if isinstance(head_dtype, torch.dtype):
        return head_dtype

    if isinstance(head_dtype, str):
        # Normalize once so the "model" sentinel is matched with the same
        # case-insensitivity as real dtype names ("Model" used to be
        # rejected while "Float32" was accepted).
        name = head_dtype.lower()
        if name == "model":
            return dtype
        if name not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {name!r}")
        return _STR_DTYPE_TO_TORCH_DTYPE[name]

    raise ValueError(f"Unknown dtype: {head_dtype}")
|
||||||
|
|
||||||
|
|
||||||
def _get_and_verify_max_len(
|
def _get_and_verify_max_len(
|
||||||
hf_config: PretrainedConfig,
|
hf_config: PretrainedConfig,
|
||||||
tokenizer_config: Optional[dict],
|
tokenizer_config: Optional[dict],
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
from typing import Callable, Optional, TypeVar, Union, cast
|
from typing import Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
|
|||||||
class PoolerNormalize(PoolerActivation):
|
class PoolerNormalize(PoolerActivation):
|
||||||
|
|
||||||
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
||||||
x = F.normalize(pooled_data.float(), p=2, dim=-1)
|
return F.normalize(pooled_data, p=2, dim=-1)
|
||||||
return x.to(pooled_data.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class PoolerMultiLabelClassify(PoolerActivation):
|
class PoolerMultiLabelClassify(PoolerActivation):
|
||||||
|
|
||||||
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
||||||
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
|
return F.sigmoid(pooled_data)
|
||||||
|
|
||||||
|
|
||||||
class PoolerClassify(PoolerActivation):
|
class PoolerClassify(PoolerActivation):
|
||||||
@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
|
|||||||
pooled_data.shape[-1])
|
pooled_data.shape[-1])
|
||||||
|
|
||||||
if num_labels < 2:
|
if num_labels < 2:
|
||||||
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
|
return F.sigmoid(pooled_data)
|
||||||
|
|
||||||
return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
|
return F.softmax(pooled_data, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
class LambdaPoolerActivation(PoolerActivation):
|
class LambdaPoolerActivation(PoolerActivation):
|
||||||
@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
|
|||||||
from vllm.model_executor.models.adapters import _load_st_projector
|
from vllm.model_executor.models.adapters import _load_st_projector
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.projector = _load_st_projector(
|
self.projector: Optional[nn.Module] = _load_st_projector(
|
||||||
vllm_config.model_config) if vllm_config else None
|
vllm_config.model_config) if vllm_config else None
|
||||||
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
|
|
||||||
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
||||||
pooling_metadata: PoolingMetadata):
|
pooling_metadata: PoolingMetadata):
|
||||||
@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
|
|||||||
pooled_data = torch.stack(pooled_data)
|
pooled_data = torch.stack(pooled_data)
|
||||||
# pooled_data shape: [batchsize, hidden_dimension]
|
# pooled_data shape: [batchsize, hidden_dimension]
|
||||||
|
|
||||||
|
pooled_data = pooled_data.to(self.head_dtype)
|
||||||
|
|
||||||
# Apply ST projector
|
# Apply ST projector
|
||||||
if self.projector is not None:
|
if self.projector is not None:
|
||||||
projector = cast(nn.Module, self.projector)
|
pooled_data = self.projector(pooled_data)
|
||||||
|
|
||||||
def _proj(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
orig_dtype = x.dtype
|
|
||||||
y = projector(x.to(torch.float32))
|
|
||||||
return y.to(orig_dtype)
|
|
||||||
|
|
||||||
pooled_data = _proj(pooled_data)
|
|
||||||
# pooled_data shape: [batchsize, embedding_dimension]
|
# pooled_data shape: [batchsize, embedding_dimension]
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
||||||
|
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
|
|
||||||
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
||||||
pooling_metadata: PoolingMetadata):
|
pooling_metadata: PoolingMetadata):
|
||||||
|
|
||||||
|
if isinstance(pooled_data, list):
|
||||||
|
pooled_data = [p.to(self.head_dtype) for p in pooled_data]
|
||||||
|
else:
|
||||||
|
pooled_data = pooled_data.to(self.head_dtype)
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
|
|
||||||
# for softmax
|
# for softmax
|
||||||
@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
|
|||||||
self.act_fn = act_fn or PoolerClassify()
|
self.act_fn = act_fn or PoolerClassify()
|
||||||
self.logit_bias: Optional[
|
self.logit_bias: Optional[
|
||||||
float] = vllm_config.model_config.pooler_config.logit_bias
|
float] = vllm_config.model_config.pooler_config.logit_bias
|
||||||
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
|
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
return {"classify", "score"}
|
return {"classify", "score"}
|
||||||
@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
|
|||||||
pooled_data = torch.stack(pooled_data)
|
pooled_data = torch.stack(pooled_data)
|
||||||
# pooled_data shape: [batchsize, hidden_size]
|
# pooled_data shape: [batchsize, hidden_size]
|
||||||
|
|
||||||
|
pooled_data = pooled_data.to(self.head_dtype)
|
||||||
|
|
||||||
if self.classifier is not None:
|
if self.classifier is not None:
|
||||||
pooled_data = self.classifier(pooled_data)
|
pooled_data = self.classifier(pooled_data)
|
||||||
# pooled_data shape: [batchsize, num_labels]
|
# pooled_data shape: [batchsize, num_labels]
|
||||||
|
|||||||
@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
|
|||||||
linear = nn.Linear(layer_config.get("in_features", 768),
|
linear = nn.Linear(layer_config.get("in_features", 768),
|
||||||
layer_config.get("out_features", 768),
|
layer_config.get("out_features", 768),
|
||||||
bias=layer_config.get("bias", True),
|
bias=layer_config.get("bias", True),
|
||||||
dtype=torch.float32)
|
dtype=model_config.head_dtype)
|
||||||
|
|
||||||
if not _load_dense_weights(linear, folder, model_config):
|
if not _load_dense_weights(linear, folder, model_config):
|
||||||
continue
|
continue
|
||||||
@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
|
|||||||
layers.append(linear)
|
layers.append(linear)
|
||||||
if act_name := layer_config.get("activation_function"):
|
if act_name := layer_config.get("activation_function"):
|
||||||
layers.append(get_act_fn(act_name))
|
layers.append(get_act_fn(act_name))
|
||||||
return nn.Sequential(*layers).to(dtype=torch.float32)
|
return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("ST projector loading failed")
|
logger.exception("ST projector loading failed")
|
||||||
|
|
||||||
@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
|
|||||||
if weight_key in state_dict:
|
if weight_key in state_dict:
|
||||||
weight_loader = getattr(linear.weight, "weight_loader",
|
weight_loader = getattr(linear.weight, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
weight_loader(linear.weight,
|
weight_loader(linear.weight, state_dict[weight_key])
|
||||||
state_dict[weight_key].to(torch.float32))
|
|
||||||
|
|
||||||
bias_key = weight_key.replace("weight", "bias")
|
bias_key = weight_key.replace("weight", "bias")
|
||||||
if linear.bias is not None and bias_key in state_dict:
|
if linear.bias is not None and bias_key in state_dict:
|
||||||
bias_loader = getattr(linear.bias, "weight_loader",
|
bias_loader = getattr(linear.bias, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
bias_loader(linear.bias,
|
bias_loader(linear.bias, state_dict[bias_key])
|
||||||
state_dict[bias_key].to(torch.float32))
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to load %s", filename)
|
logger.exception("Failed to load %s", filename)
|
||||||
|
|||||||
@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
|||||||
self.bert = BertPoolingModel(vllm_config=vllm_config,
|
self.bert = BertPoolingModel(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "bert"),
|
prefix=maybe_prefix(prefix, "bert"),
|
||||||
embedding_class=BertEmbedding)
|
embedding_class=BertEmbedding)
|
||||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
self.classifier = nn.Linear(config.hidden_size,
|
||||||
|
config.num_labels,
|
||||||
|
dtype=vllm_config.model_config.head_dtype)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|||||||
@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
self.new = GteNewModel(vllm_config=vllm_config,
|
self.new = GteNewModel(vllm_config=vllm_config,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
add_pooling_layer=True)
|
add_pooling_layer=True)
|
||||||
self.classifier = RowParallelLinear(config.hidden_size,
|
self.classifier = ReplicatedLinear(
|
||||||
config.num_labels,
|
config.hidden_size,
|
||||||
input_is_parallel=False,
|
config.num_labels,
|
||||||
bias=True,
|
bias=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=maybe_prefix(
|
params_dtype=vllm_config.model_config.head_dtype,
|
||||||
prefix, "classifier"),
|
prefix=maybe_prefix(prefix, "classifier"),
|
||||||
return_bias=False)
|
return_bias=False)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|||||||
@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
|
|||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
self.transformer = GPT2Model(vllm_config=vllm_config,
|
self.transformer = GPT2Model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "gpt2"))
|
prefix=maybe_prefix(prefix, "gpt2"))
|
||||||
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
|
self.score = nn.Linear(config.n_embd,
|
||||||
|
config.num_labels,
|
||||||
|
bias=False,
|
||||||
|
dtype=vllm_config.model_config.head_dtype)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
|
|||||||
"encode":
|
"encode":
|
||||||
Pooler.for_encode(pooler_config),
|
Pooler.for_encode(pooler_config),
|
||||||
"classify":
|
"classify":
|
||||||
Pooler.for_classify(pooler_config, classifier=None),
|
Pooler.for_classify(pooler_config, classifier=self.score),
|
||||||
})
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
|
|||||||
position_ids=positions,
|
position_ids=positions,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
intermediate_tensors=intermediate_tensors)
|
intermediate_tensors=intermediate_tensors)
|
||||||
logits = self.score(hidden_states)
|
return hidden_states
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def _add_transformer_prefix(
|
def _add_transformer_prefix(
|
||||||
|
|||||||
@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
|||||||
delattr(self, attr)
|
delattr(self, attr)
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
self.v_head = RowParallelLinear(
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
config.hidden_size,
|
|
||||||
1,
|
self.v_head = RowParallelLinear(config.hidden_size,
|
||||||
bias=False,
|
1,
|
||||||
input_is_parallel=False,
|
bias=False,
|
||||||
prefix=maybe_prefix(prefix, "v_head"),
|
input_is_parallel=False,
|
||||||
)
|
params_dtype=self.head_dtype,
|
||||||
|
prefix=maybe_prefix(prefix, "v_head"),
|
||||||
|
return_bias=False)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
|||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
inputs_embeds)
|
inputs_embeds)
|
||||||
logits, _ = self.v_head(hidden_states)
|
hidden_states = hidden_states.to(self.head_dtype)
|
||||||
|
logits = self.v_head(hidden_states)
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
num_labels,
|
num_labels,
|
||||||
bias=score_bias,
|
bias=score_bias,
|
||||||
dtype=torch.float32,
|
dtype=vllm_config.model_config.head_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
|||||||
@ -5,9 +5,9 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import BatchFeature, PretrainedConfig
|
from transformers import BatchFeature
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
from vllm.inputs import TokensPrompt
|
from vllm.inputs import TokensPrompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -28,13 +28,17 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
class JinaVLScorer(nn.Module):
|
class JinaVLScorer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig):
|
def __init__(self, model_config: "ModelConfig"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
config = model_config.hf_config
|
||||||
|
head_dtype = model_config.head_dtype
|
||||||
self.dense = ColumnParallelLinear(config.hidden_size,
|
self.dense = ColumnParallelLinear(config.hidden_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
params_dtype=head_dtype,
|
||||||
bias=True)
|
bias=True)
|
||||||
self.out_proj = RowParallelLinear(config.hidden_size,
|
self.out_proj = RowParallelLinear(config.hidden_size,
|
||||||
config.num_labels,
|
config.num_labels,
|
||||||
|
params_dtype=head_dtype,
|
||||||
bias=True)
|
bias=True)
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
def forward(self, x, **kwargs):
|
||||||
@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__(vllm_config=vllm_config,
|
super().__init__(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "qwen2_vl"))
|
prefix=maybe_prefix(prefix, "qwen2_vl"))
|
||||||
config = vllm_config.model_config.hf_config
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.score = JinaVLScorer(config)
|
self.score = JinaVLScorer(vllm_config.model_config)
|
||||||
self.pooler = DispatchPooler({
|
self.pooler = DispatchPooler({
|
||||||
"encode":
|
"encode":
|
||||||
Pooler.for_encode(pooler_config),
|
Pooler.for_encode(pooler_config),
|
||||||
|
|||||||
@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.model = ModernBertModel(vllm_config=vllm_config,
|
self.model = ModernBertModel(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "modernbert"))
|
prefix=maybe_prefix(prefix, "modernbert"))
|
||||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
self.classifier = nn.Linear(config.hidden_size,
|
||||||
|
config.num_labels,
|
||||||
|
dtype=vllm_config.model_config.head_dtype)
|
||||||
self.pooling = ModernBertPooler(config)
|
self.pooling = ModernBertPooler(config)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
|||||||
@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
|
|
||||||
self.score = nn.Sequential(
|
self.score = nn.Sequential(
|
||||||
ColumnParallelLinear(config.hidden_size,
|
ColumnParallelLinear(config.hidden_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
params_dtype=self.head_dtype,
|
||||||
return_bias=False),
|
return_bias=False),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
RowParallelLinear(config.hidden_size,
|
RowParallelLinear(config.hidden_size,
|
||||||
config.num_labels,
|
config.num_labels,
|
||||||
|
params_dtype=self.head_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
return_bias=False),
|
return_bias=False),
|
||||||
)
|
)
|
||||||
@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
inputs_embeds)
|
inputs_embeds)
|
||||||
|
hidden_states = hidden_states.to(self.head_dtype)
|
||||||
logits = self.score(hidden_states)
|
logits = self.score(hidden_states)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import RobertaConfig
|
from transformers import RobertaConfig
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
|
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
|
||||||
DispatchPooler, Pooler)
|
DispatchPooler, Pooler)
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
|
|||||||
class RobertaClassificationHead(nn.Module):
|
class RobertaClassificationHead(nn.Module):
|
||||||
"""Head for sentence-level classification tasks."""
|
"""Head for sentence-level classification tasks."""
|
||||||
|
|
||||||
def __init__(self, config: RobertaConfig):
|
def __init__(self, model_config: "ModelConfig"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
config = model_config.hf_config
|
||||||
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
|
head_dtype = model_config.head_dtype
|
||||||
|
self.dense = nn.Linear(config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
dtype=head_dtype)
|
||||||
|
self.out_proj = nn.Linear(config.hidden_size,
|
||||||
|
config.num_labels,
|
||||||
|
dtype=head_dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# CLSPool has already been applied in `pooling`
|
# CLSPool has already been applied in `pooling`
|
||||||
@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
self.roberta = BertModel(vllm_config=vllm_config,
|
self.roberta = BertModel(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "bert"),
|
prefix=maybe_prefix(prefix, "bert"),
|
||||||
embedding_class=RobertaEmbedding)
|
embedding_class=RobertaEmbedding)
|
||||||
self.classifier = RobertaClassificationHead(config)
|
self.classifier = RobertaClassificationHead(vllm_config.model_config)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user