[Model] Automatic conversion of TokenClassification model (#30666)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi 2025-12-15 16:13:00 +08:00 committed by GitHub
parent 33278073d6
commit 4429d934de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 45 additions and 0 deletions

View File

@ -68,3 +68,34 @@ def test_modernbert_models(
hf_output = torch.tensor(hf_output).cpu().float()
vllm_output = torch.tensor(vllm_output).cpu().float()
assert torch.allclose(hf_output, vllm_output, atol=1e-2)
@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_auto_conversion(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Verify vLLM's automatic ForTokenClassification conversion.

    Runs the same prompts through vLLM's token-classification path and the
    HF reference model, and checks the per-token probability distributions
    agree within tolerance.
    """
    with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.token_classify(example_prompts)

    with hf_runner(
        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
    ) as hf_model:
        tokenizer = hf_model.tokenizer

        def reference(prompt):
            # Tokenize one prompt, move it onto the model's device, and
            # return softmax-normalized per-token logits.
            batch = hf_model.wrap_device(tokenizer([prompt], return_tensors="pt"))
            return softmax(hf_model.model(**batch).logits[0])

        hf_outputs = [reference(p) for p in example_prompts]

    # check logits difference
    for expected, actual in zip(hf_outputs, vllm_outputs):
        expected = torch.tensor(expected).cpu().float()
        actual = torch.tensor(actual).cpu().float()
        assert torch.allclose(expected, actual, atol=1e-2)

View File

@ -573,6 +573,7 @@ _AUTOMATIC_CONVERTED_MODELS = {
"Qwen3ForSequenceClassification": _HfExamplesInfo(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
),
"Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"),
}
_MULTIMODAL_EXAMPLE_MODELS = {

View File

@ -1796,6 +1796,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForTextEncoding", ("pooling", "embed")),
("EmbeddingModel", ("pooling", "embed")),
("ForSequenceClassification", ("pooling", "classify")),
("ForTokenClassification", ("pooling", "classify")),
("ForAudioClassification", ("pooling", "classify")),
("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")),

View File

@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T:
tokens = getattr(text_config, "classifier_from_token", None)
method = getattr(text_config, "method", None)
def auto_set_score_bias(weights):
    """Forward checkpoint weights, intercepting ``score.bias``.

    The converted classification head's bias is installed directly on
    ``self.score`` (matching the score weight's device/dtype) rather
    than yielded for the normal weight-loading path; all other weights
    pass through unchanged. NOTE(review): clearing ``skip_bias_add``
    presumably makes the layer apply the bias in forward — confirm.
    """
    for wname, tensor in weights:
        if wname != "score.bias":
            yield wname, tensor
            continue
        ref = self.score.weight
        self.score.bias = torch.nn.Parameter(
            tensor.to(ref.device).to(ref.dtype)
        )
        self.score.skip_bias_add = False
weights = auto_set_score_bias(weights)
if tokens is None and method is None:
return super().load_weights(weights)
else: