diff --git a/tests/conftest.py b/tests/conftest.py
index 3f3790cab8d3..2bf88abb0f6c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -456,7 +456,15 @@ class HfRunner:
         outputs = []
         for inputs in all_inputs:
             output = self.model(**self.wrap_device(inputs))
-            logits = output.logits.softmax(dim=-1)[0].tolist()
+
+            problem_type = getattr(self.config, "problem_type", "")
+
+            if problem_type == "regression":
+                logits = output.logits[0].tolist()
+            elif problem_type == "multi_label_classification":
+                logits = output.logits.sigmoid()[0].tolist()
+            else:
+                logits = output.logits.softmax(dim=-1)[0].tolist()
             outputs.append(logits)
 
         return outputs
diff --git a/tests/models/language/pooling/test_multilabel_classification_support.py b/tests/models/language/pooling/test_multilabel_classification_support.py
new file mode 100644
index 000000000000..45366f209414
--- /dev/null
+++ b/tests/models/language/pooling/test_multilabel_classification_support.py
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["Rami/multi-label-class-classification-on-github-issues"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.classify(example_prompts)
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=AutoModelForSequenceClassification) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output,
+                              1e-3 if dtype == "float" else 1e-2)
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index e2162e5cbf95..75e65072b701 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -172,6 +172,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
 
 
 def get_classification_activation_function(config: PretrainedConfig):
+    # Implement alignment with transformers ForSequenceClassificationLoss
+    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
+    problem_type = getattr(config, "problem_type", "")
+    if problem_type == "regression":
+        return PoolerIdentity()
+    if problem_type == "single_label_classification":
+        return PoolerClassify()
+    if problem_type == "multi_label_classification":
+        return PoolerMultiLabelClassify()
     return PoolerClassify()
 
 
@@ -409,6 +418,12 @@ class PoolerNormalize(PoolerActivation):
         return x.to(pooled_data.dtype)
 
 
+class PoolerMultiLabelClassify(PoolerActivation):
+
+    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
+        return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
+
+
 class PoolerClassify(PoolerActivation):
 
     def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
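
For reference, the dispatch added to `get_classification_activation_function` follows the `problem_type` convention that transformers uses when selecting a sequence-classification loss: no activation for regression, sigmoid for multi-label, softmax otherwise. Below is a minimal standalone sketch of that behavior; `DummyConfig` and `activate` are hypothetical names for illustration only, not vLLM or transformers API:

```python
import torch


class DummyConfig:
    """Hypothetical stand-in for a transformers PretrainedConfig."""

    def __init__(self, problem_type: str = ""):
        self.problem_type = problem_type


def activate(logits: torch.Tensor, config) -> torch.Tensor:
    # Mirrors the dispatch in the diff: problem_type picks the activation.
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return logits  # identity: raw scores pass through
    if problem_type == "multi_label_classification":
        return logits.sigmoid()  # independent per-label probabilities
    return logits.softmax(dim=-1)  # single-label / default: sums to 1


logits = torch.tensor([[2.0, -1.0, 0.5]])
print(activate(logits, DummyConfig("multi_label_classification")))
# tensor([[0.8808, 0.2689, 0.6225]]) -- each label scored independently
print(activate(logits, DummyConfig("single_label_classification")))
# tensor([[0.7856, 0.0391, 0.1753]]) -- probabilities sum to 1
```

This is why the `HfRunner.classify` change in `tests/conftest.py` applies the matching activation on the HF side: both backends must post-process logits the same way for the `torch.allclose` comparison in the new test to be meaningful.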