mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
[Model] Add multi_label_classification support (#23173)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
03752dba8f
commit
f856c33ce9
@ -456,7 +456,15 @@ class HfRunner:
|
|||||||
outputs = []
|
outputs = []
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
output = self.model(**self.wrap_device(inputs))
|
output = self.model(**self.wrap_device(inputs))
|
||||||
logits = output.logits.softmax(dim=-1)[0].tolist()
|
|
||||||
|
problem_type = getattr(self.config, "problem_type", "")
|
||||||
|
|
||||||
|
if problem_type == "regression":
|
||||||
|
logits = output.logits[0].tolist()
|
||||||
|
elif problem_type == "multi_label_classification":
|
||||||
|
logits = output.logits.sigmoid()[0].tolist()
|
||||||
|
else:
|
||||||
|
logits = output.logits.softmax(dim=-1)[0].tolist()
|
||||||
outputs.append(logits)
|
outputs.append(logits)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@ -0,0 +1,33 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model",
    ["Rami/multi-label-class-classification-on-github-issues"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check that vLLM classification scores match the HF reference.

    The same prompts are classified once through the vLLM runner and once
    through the Hugging Face runner (loaded with
    ``AutoModelForSequenceClassification``); the per-prompt score vectors
    must then agree within a dtype-dependent relative tolerance.
    """
    with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.classify(example_prompts)

    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForSequenceClassification) as hf_model:
        hf_outputs = hf_model.classify(example_prompts)

    # zip() stops at the shorter sequence, so without this check a runner
    # that silently dropped prompts would still pass the comparison below.
    assert len(hf_outputs) == len(vllm_outputs)

    # Any dtype other than "float" is compared with the looser tolerance.
    rtol = 1e-3 if dtype == "float" else 1e-2
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output)
        vllm_output = torch.tensor(vllm_output)

        assert torch.allclose(hf_output, vllm_output, rtol=rtol)
|
||||||
@ -172,6 +172,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
|
|||||||
|
|
||||||
|
|
||||||
def get_classification_activation_function(config: PretrainedConfig):
    """Choose the pooler activation that matches ``config.problem_type``.

    Mirrors the problem-type dispatch of transformers'
    ForSequenceClassificationLoss:
    https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92

    Returns:
        ``PoolerIdentity()`` for ``"regression"``,
        ``PoolerMultiLabelClassify()`` for ``"multi_label_classification"``,
        and ``PoolerClassify()`` for every other value, including a config
        that has no ``problem_type`` attribute.
    """
    kind = getattr(config, "problem_type", "")

    if kind == "regression":
        return PoolerIdentity()

    if kind == "multi_label_classification":
        return PoolerMultiLabelClassify()

    # "single_label_classification" shares the fallback: it and any
    # missing or unrecognised problem type map to PoolerClassify.
    return PoolerClassify()
|
||||||
|
|
||||||
|
|
||||||
@ -409,6 +418,12 @@ class PoolerNormalize(PoolerActivation):
|
|||||||
return x.to(pooled_data.dtype)
|
return x.to(pooled_data.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolerMultiLabelClassify(PoolerActivation):
    """Pooler activation that applies an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(pooled_data)`` in the input's original dtype."""
        original_dtype = pooled_data.dtype
        # Evaluate the sigmoid in float32, then cast back to the input dtype.
        scores = torch.sigmoid(pooled_data.float())
        return scores.to(original_dtype)
|
||||||
|
|
||||||
|
|
||||||
class PoolerClassify(PoolerActivation):
|
class PoolerClassify(PoolerActivation):
|
||||||
|
|
||||||
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user