mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 22:46:01 +08:00
[Model] Classification models support logit_bias / sigmoid_normalize (#24031)
Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
38ba061f6f
commit
e0653f6c0b
@ -2651,24 +2651,46 @@ class PoolerConfig:
|
|||||||
## for embeddings models
|
## for embeddings models
|
||||||
normalize: Optional[bool] = None
|
normalize: Optional[bool] = None
|
||||||
"""
|
"""
|
||||||
Whether to normalize the embeddings outputs.
|
Whether to normalize the embeddings outputs. Defaults to True.
|
||||||
"""
|
"""
|
||||||
dimensions: Optional[int] = None
|
dimensions: Optional[int] = None
|
||||||
"""
|
"""
|
||||||
Reduce the dimensions of embeddings if model
|
Reduce the dimensions of embeddings if the model
|
||||||
support matryoshka representation.
|
supports matryoshka representation. Defaults to None.
|
||||||
|
"""
|
||||||
|
enable_chunked_processing: Optional[bool] = None
|
||||||
|
"""
|
||||||
|
Whether to enable chunked processing for long inputs that exceed the model's
|
||||||
|
maximum position embeddings. When enabled, long inputs will be split into
|
||||||
|
chunks, processed separately, and then aggregated using weighted averaging.
|
||||||
|
This allows embedding models to handle arbitrarily long text without CUDA
|
||||||
|
errors. Defaults to False.
|
||||||
|
"""
|
||||||
|
max_embed_len: Optional[int] = None
|
||||||
|
"""
|
||||||
|
Maximum input length allowed for embedding generation. When set, allows
|
||||||
|
inputs longer than max_model_len (up to max_embed_len) to be accepted for embedding models.
|
||||||
|
When an input exceeds max_embed_len, it will be handled according to
|
||||||
|
the original max_model_len validation logic.
|
||||||
|
Defaults to None (i.e. set to max_model_len).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
## for classification models
|
## for classification models
|
||||||
activation: Optional[bool] = None
|
activation: Optional[bool] = None
|
||||||
"""
|
"""
|
||||||
Whether to apply activation function to the classification outputs.
|
Whether to apply activation function to the classification outputs.
|
||||||
|
Defaults to True.
|
||||||
|
"""
|
||||||
|
logit_bias: Optional[float] = None
|
||||||
|
"""
|
||||||
|
If provided, this value is subtracted from the classification logits. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
## for reward models
|
## for reward models
|
||||||
softmax: Optional[bool] = None
|
softmax: Optional[bool] = None
|
||||||
"""
|
"""
|
||||||
Whether to apply softmax to the reward outputs.
|
Whether to apply softmax to the reward outputs.
|
||||||
|
Defaults to True.
|
||||||
"""
|
"""
|
||||||
step_tag_id: Optional[int] = None
|
step_tag_id: Optional[int] = None
|
||||||
"""
|
"""
|
||||||
@ -2683,25 +2705,6 @@ class PoolerConfig:
|
|||||||
``math-shepherd-mistral-7b-prm`` model.
|
``math-shepherd-mistral-7b-prm`` model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
enable_chunked_processing: Optional[bool] = None
|
|
||||||
"""
|
|
||||||
Whether to enable chunked processing for long inputs that exceed the model's
|
|
||||||
maximum position embeddings. When enabled, long inputs will be split into
|
|
||||||
chunks, processed separately, and then aggregated using weighted averaging.
|
|
||||||
This allows embedding models to handle arbitrarily long text without CUDA
|
|
||||||
errors. Defaults to False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_embed_len: Optional[int] = None
|
|
||||||
"""
|
|
||||||
Maximum input length allowed for embedding generation. When set, allows
|
|
||||||
inputs longer than max_embed_len to be accepted for embedding models.
|
|
||||||
This parameter enables accepting long inputs without requiring
|
|
||||||
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
|
|
||||||
max_embed_len, it will be handled according to the original max_model_len
|
|
||||||
validation logic. Defaults to None (i.e. set to max_model_len).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
|||||||
@ -633,9 +633,14 @@ class ClassifierPooler(Pooler):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
|
||||||
self.pooling = pooling
|
self.pooling = pooling
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
self.act_fn = act_fn or PoolerClassify()
|
self.act_fn = act_fn or PoolerClassify()
|
||||||
|
self.logit_bias: Optional[
|
||||||
|
float] = vllm_config.model_config.pooler_config.logit_bias
|
||||||
|
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
return {"classify", "score"}
|
return {"classify", "score"}
|
||||||
@ -654,6 +659,9 @@ class ClassifierPooler(Pooler):
|
|||||||
pooled_data = self.classifier(pooled_data)
|
pooled_data = self.classifier(pooled_data)
|
||||||
# pooled_data shape: [batchsize, num_labels]
|
# pooled_data shape: [batchsize, num_labels]
|
||||||
|
|
||||||
|
if self.logit_bias is not None:
|
||||||
|
pooled_data -= self.logit_bias
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
flags = [p.activation for p in pooling_params]
|
flags = [p.activation for p in pooling_params]
|
||||||
|
|
||||||
|
|||||||
@ -210,8 +210,10 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
|
|
||||||
config.num_labels = 1
|
config.num_labels = 1
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
if pooler_config.logit_bias is None:
|
||||||
|
pooler_config.logit_bias = 2.65
|
||||||
|
|
||||||
|
|
||||||
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||||
|
|||||||
@ -92,17 +92,14 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
|||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
# logit bias for sigmoid normalization
|
|
||||||
self.LOGIT_BIAS = 2.65
|
|
||||||
|
|
||||||
self.score = JinaVLScorer(config)
|
self.score = JinaVLScorer(config)
|
||||||
self.pooler = DispatchPooler({
|
self.pooler = DispatchPooler({
|
||||||
"encode":
|
"encode":
|
||||||
Pooler.for_encode(pooler_config),
|
Pooler.for_encode(pooler_config),
|
||||||
"classify":
|
"classify":
|
||||||
Pooler.for_classify(pooler_config, classifier=None),
|
Pooler.for_classify(pooler_config, classifier=self.score),
|
||||||
"score":
|
"score":
|
||||||
Pooler.for_classify(pooler_config, classifier=None),
|
Pooler.for_classify(pooler_config, classifier=self.score),
|
||||||
})
|
})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -137,9 +134,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
return hidden_states
|
||||||
logits = self.score(hidden_states) - self.LOGIT_BIAS
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user