mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00
[Model] Add multi_label_classification support (#23173)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
03752dba8f
commit
f856c33ce9
@ -456,7 +456,15 @@ class HfRunner:
|
||||
outputs = []
|
||||
for inputs in all_inputs:
|
||||
output = self.model(**self.wrap_device(inputs))
|
||||
logits = output.logits.softmax(dim=-1)[0].tolist()
|
||||
|
||||
problem_type = getattr(self.config, "problem_type", "")
|
||||
|
||||
if problem_type == "regression":
|
||||
logits = output.logits[0].tolist()
|
||||
elif problem_type == "multi_label_classification":
|
||||
logits = output.logits.sigmoid()[0].tolist()
|
||||
else:
|
||||
logits = output.logits.softmax(dim=-1)[0].tolist()
|
||||
outputs.append(logits)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["Rami/multi-label-class-classification-on-github-issues"],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_classify_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
with hf_runner(model,
|
||||
dtype=dtype,
|
||||
auto_cls=AutoModelForSequenceClassification) as hf_model:
|
||||
hf_outputs = hf_model.classify(example_prompts)
|
||||
|
||||
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||
hf_output = torch.tensor(hf_output)
|
||||
vllm_output = torch.tensor(vllm_output)
|
||||
|
||||
assert torch.allclose(hf_output, vllm_output,
|
||||
1e-3 if dtype == "float" else 1e-2)
|
||||
@ -172,6 +172,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
|
||||
|
||||
|
||||
def get_classification_activation_function(config: PretrainedConfig):
|
||||
# Implement alignment with transformers ForSequenceClassificationLoss
|
||||
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
|
||||
problem_type = getattr(config, "problem_type", "")
|
||||
if problem_type == "regression":
|
||||
return PoolerIdentity()
|
||||
if problem_type == "single_label_classification":
|
||||
return PoolerClassify()
|
||||
if problem_type == "multi_label_classification":
|
||||
return PoolerMultiLabelClassify()
|
||||
return PoolerClassify()
|
||||
|
||||
|
||||
@ -409,6 +418,12 @@ class PoolerNormalize(PoolerActivation):
|
||||
return x.to(pooled_data.dtype)
|
||||
|
||||
|
||||
class PoolerMultiLabelClassify(PoolerActivation):
|
||||
|
||||
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
||||
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
|
||||
|
||||
|
||||
class PoolerClassify(PoolerActivation):
|
||||
|
||||
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user