[Model] Add multi_label_classification support (#23173)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi 2025-08-19 20:54:30 +08:00 committed by GitHub
parent 03752dba8f
commit f856c33ce9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 57 additions and 1 deletions

View File

@ -456,6 +456,14 @@ class HfRunner:
outputs = []
for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs))
problem_type = getattr(self.config, "problem_type", "")
if problem_type == "regression":
logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification":
logits = output.logits.sigmoid()[0].tolist()
else:
logits = output.logits.softmax(dim=-1)[0].tolist()
outputs.append(logits)

View File

@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModelForSequenceClassification
@pytest.mark.parametrize(
"model",
["Rami/multi-label-class-classification-on-github-issues"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)

View File

@ -172,6 +172,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
def get_classification_activation_function(config: PretrainedConfig):
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type = getattr(config, "problem_type", "")
if problem_type == "regression":
return PoolerIdentity()
if problem_type == "single_label_classification":
return PoolerClassify()
if problem_type == "multi_label_classification":
return PoolerMultiLabelClassify()
return PoolerClassify()
@ -409,6 +418,12 @@ class PoolerNormalize(PoolerActivation):
return x.to(pooled_data.dtype)
class PoolerMultiLabelClassify(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
class PoolerClassify(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: