[Model] Classification models support logit_bias / sigmoid_normalize (#24031)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Authored by wang.yuqi on 2025-09-03 00:48:57 +08:00; committed by GitHub
parent 38ba061f6f
commit e0653f6c0b
4 changed files with 38 additions and 30 deletions

View File

@@ -2651,24 +2651,46 @@ class PoolerConfig:
     ## for embeddings models
     normalize: Optional[bool] = None
     """
-    Whether to normalize the embeddings outputs.
+    Whether to normalize the embeddings outputs. Defaults to True.
     """
     dimensions: Optional[int] = None
     """
     Reduce the dimensions of embeddings if model
-    support matryoshka representation.
+    support matryoshka representation. Defaults to None.
+    """
+    enable_chunked_processing: Optional[bool] = None
+    """
+    Whether to enable chunked processing for long inputs that exceed the model's
+    maximum position embeddings. When enabled, long inputs will be split into
+    chunks, processed separately, and then aggregated using weighted averaging.
+    This allows embedding models to handle arbitrarily long text without CUDA
+    errors. Defaults to False.
+    """
+    max_embed_len: Optional[int] = None
+    """
+    Maximum input length allowed for embedding generation. When set, allows
+    inputs longer than max_embed_len to be accepted for embedding models.
+    When an input exceeds max_embed_len, it will be handled according to
+    the original max_model_len validation logic.
+    Defaults to None (i.e. set to max_model_len).
     """

     ## for classification models
     activation: Optional[bool] = None
     """
     Whether to apply activation function to the classification outputs.
+    Defaults to True.
+    """
+    logit_bias: Optional[float] = None
+    """
+    If provided, apply classification logit biases. Defaults to None.
     """

     ## for reward models
     softmax: Optional[bool] = None
     """
     Whether to apply softmax to the reward outputs.
+    Defaults to True.
     """
     step_tag_id: Optional[int] = None
     """
@@ -2683,25 +2705,6 @@ class PoolerConfig:
     ``math-shepherd-mistral-7b-prm`` model.
     """
-
-    enable_chunked_processing: Optional[bool] = None
-    """
-    Whether to enable chunked processing for long inputs that exceed the model's
-    maximum position embeddings. When enabled, long inputs will be split into
-    chunks, processed separately, and then aggregated using weighted averaging.
-    This allows embedding models to handle arbitrarily long text without CUDA
-    errors. Defaults to False.
-    """
-    max_embed_len: Optional[int] = None
-    """
-    Maximum input length allowed for embedding generation. When set, allows
-    inputs longer than max_embed_len to be accepted for embedding models.
-    This parameter enables accepting long inputs without requiring
-    VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
-    max_embed_len, it will be handled according to the original max_model_len
-    validation logic. Defaults to None (i.e. set to max_model_len).
-    """

     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
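With this change, a classification logit bias can be supplied anywhere a PoolerConfig can. A minimal usage sketch, assuming the existing override_pooler_config engine argument (the model name and bias value are only examples):

    from vllm import LLM
    from vllm.config import PoolerConfig

    # Example only: shift classification logits down by 1.0 before the
    # activation (sigmoid/softmax) is applied by the pooler.
    llm = LLM(
        model="jinaai/jina-reranker-m0",  # any classification model
        override_pooler_config=PoolerConfig(logit_bias=1.0),
    )
    print(llm.classify(["some text to classify"]))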

View File

@@ -633,9 +633,14 @@ class ClassifierPooler(Pooler):
     ) -> None:
         super().__init__()

+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+
         self.pooling = pooling
         self.classifier = classifier
         self.act_fn = act_fn or PoolerClassify()
+        self.logit_bias: Optional[
+            float] = vllm_config.model_config.pooler_config.logit_bias

     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"classify", "score"}
@@ -654,6 +659,9 @@ class ClassifierPooler(Pooler):
         pooled_data = self.classifier(pooled_data)
         # pooled_data shape: [batchsize, num_labels]

+        if self.logit_bias is not None:
+            pooled_data -= self.logit_bias
+
         pooling_params = get_pooling_params(pooling_metadata)
         flags = [p.activation for p in pooling_params]
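Because the subtraction happens before the activation, a sigmoid-normalized score is shifted in odds space: sigmoid(x - b) is the probability whose odds are exp(x) / exp(b). A self-contained check of that identity (plain PyTorch, no vLLM code):

    import math
    import torch

    logits = torch.tensor([0.0, 2.65, 5.3])
    bias = 2.65

    # What the pooler computes for "classify"/"score" under sigmoid activation.
    probs = torch.sigmoid(logits - bias)

    # Odds-space view: subtracting `bias` divides the odds exp(x) by exp(bias).
    odds = torch.exp(logits) / math.exp(bias)
    assert torch.allclose(probs, odds / (1 + odds))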

View File

@@ -210,8 +210,10 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         config = vllm_config.model_config.hf_config
         config.num_labels = 1
+        pooler_config = vllm_config.model_config.pooler_config
+        if pooler_config.logit_bias is None:
+            pooler_config.logit_bias = 2.65


 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
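The `is None` guard means 2.65 is only a fallback: any user-provided value, including an explicit 0.0, wins over the model default. A tiny sketch of that semantics (resolve_logit_bias is a made-up helper, not vLLM code):

    from typing import Optional

    def resolve_logit_bias(user_value: Optional[float]) -> float:
        # Same pattern as verify_and_update_config above: fall back to the
        # model's calibrated default only when the user supplied nothing.
        return 2.65 if user_value is None else user_value

    assert resolve_logit_bias(None) == 2.65  # default for this model
    assert resolve_logit_bias(0.0) == 0.0    # explicit 0.0 disables the shift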

View File

@@ -92,17 +92,14 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        # logit bias for sigmoid normalization
-        self.LOGIT_BIAS = 2.65
-
         self.score = JinaVLScorer(config)
         self.pooler = DispatchPooler({
             "encode":
             Pooler.for_encode(pooler_config),
             "classify":
-            Pooler.for_classify(pooler_config, classifier=None),
+            Pooler.for_classify(pooler_config, classifier=self.score),
             "score":
-            Pooler.for_classify(pooler_config, classifier=None),
+            Pooler.for_classify(pooler_config, classifier=self.score),
         })

     @classmethod
@@ -137,9 +134,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
             inputs_embeds=inputs_embeds,
             **kwargs,
         )
-
-        logits = self.score(hidden_states) - self.LOGIT_BIAS
-        return logits
+        return hidden_states

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
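The model now returns raw hidden states, and the shared ClassifierPooler owns the classify -> bias -> activation sequence, so the removed in-model `self.score(hidden_states) - self.LOGIT_BIAS` path yields the same logits as the new pooler path. A toy end-to-end sketch of that flow (every name below is a stand-in, not a vLLM class):

    from typing import Optional

    import torch
    import torch.nn as nn

    class TinyClassifierPooler(nn.Module):
        """Toy stand-in for ClassifierPooler: classify, bias, activate."""

        def __init__(self, classifier: nn.Module,
                     logit_bias: Optional[float] = None):
            super().__init__()
            self.classifier = classifier
            self.logit_bias = logit_bias

        def forward(self, pooled: torch.Tensor) -> torch.Tensor:
            logits = self.classifier(pooled)
            if self.logit_bias is not None:
                logits = logits - self.logit_bias
            return torch.sigmoid(logits)  # activation for a 1-label scorer

    score = nn.Linear(8, 1)                    # stand-in for JinaVLScorer
    pooler = TinyClassifierPooler(score, logit_bias=2.65)
    print(pooler(torch.randn(4, 8)))           # scores in (0, 1)

Centralizing the bias in the pooler lets any classification model opt in through PoolerConfig instead of hard-coding a constant in its forward pass.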