mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 22:17:58 +08:00
[Model] Automatic conversion of TokenClassification model (#30666)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
parent
33278073d6
commit
4429d934de
@ -68,3 +68,34 @@ def test_modernbert_models(
|
|||||||
hf_output = torch.tensor(hf_output).cpu().float()
|
hf_output = torch.tensor(hf_output).cpu().float()
|
||||||
vllm_output = torch.tensor(vllm_output).cpu().float()
|
vllm_output = torch.tensor(vllm_output).cpu().float()
|
||||||
assert torch.allclose(hf_output, vllm_output, atol=1e-2)
|
assert torch.allclose(hf_output, vllm_output, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
|
||||||
|
@pytest.mark.parametrize("dtype", ["float"])
|
||||||
|
@torch.inference_mode
|
||||||
|
def test_auto_conversion(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
) -> None:
|
||||||
|
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
|
||||||
|
vllm_outputs = vllm_model.token_classify(example_prompts)
|
||||||
|
|
||||||
|
with hf_runner(
|
||||||
|
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
|
||||||
|
) as hf_model:
|
||||||
|
tokenizer = hf_model.tokenizer
|
||||||
|
hf_outputs = []
|
||||||
|
for prompt in example_prompts:
|
||||||
|
inputs = tokenizer([prompt], return_tensors="pt")
|
||||||
|
inputs = hf_model.wrap_device(inputs)
|
||||||
|
output = hf_model.model(**inputs)
|
||||||
|
hf_outputs.append(softmax(output.logits[0]))
|
||||||
|
|
||||||
|
# check logits difference
|
||||||
|
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||||
|
hf_output = torch.tensor(hf_output).cpu().float()
|
||||||
|
vllm_output = torch.tensor(vllm_output).cpu().float()
|
||||||
|
assert torch.allclose(hf_output, vllm_output, atol=1e-2)
|
||||||
|
|||||||
@ -573,6 +573,7 @@ _AUTOMATIC_CONVERTED_MODELS = {
|
|||||||
"Qwen3ForSequenceClassification": _HfExamplesInfo(
|
"Qwen3ForSequenceClassification": _HfExamplesInfo(
|
||||||
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
|
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
|
||||||
),
|
),
|
||||||
|
"Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"),
|
||||||
}
|
}
|
||||||
|
|
||||||
_MULTIMODAL_EXAMPLE_MODELS = {
|
_MULTIMODAL_EXAMPLE_MODELS = {
|
||||||
|
|||||||
@ -1796,6 +1796,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
|
|||||||
("ForTextEncoding", ("pooling", "embed")),
|
("ForTextEncoding", ("pooling", "embed")),
|
||||||
("EmbeddingModel", ("pooling", "embed")),
|
("EmbeddingModel", ("pooling", "embed")),
|
||||||
("ForSequenceClassification", ("pooling", "classify")),
|
("ForSequenceClassification", ("pooling", "classify")),
|
||||||
|
("ForTokenClassification", ("pooling", "classify")),
|
||||||
("ForAudioClassification", ("pooling", "classify")),
|
("ForAudioClassification", ("pooling", "classify")),
|
||||||
("ForImageClassification", ("pooling", "classify")),
|
("ForImageClassification", ("pooling", "classify")),
|
||||||
("ForVideoClassification", ("pooling", "classify")),
|
("ForVideoClassification", ("pooling", "classify")),
|
||||||
|
|||||||
@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T:
|
|||||||
tokens = getattr(text_config, "classifier_from_token", None)
|
tokens = getattr(text_config, "classifier_from_token", None)
|
||||||
method = getattr(text_config, "method", None)
|
method = getattr(text_config, "method", None)
|
||||||
|
|
||||||
|
def auto_set_score_bias(weights):
|
||||||
|
for name, weight in weights:
|
||||||
|
if name == "score.bias":
|
||||||
|
device = self.score.weight.device
|
||||||
|
dtype = self.score.weight.dtype
|
||||||
|
bias = weight.to(device).to(dtype)
|
||||||
|
self.score.bias = torch.nn.Parameter(bias)
|
||||||
|
self.score.skip_bias_add = False
|
||||||
|
else:
|
||||||
|
yield name, weight
|
||||||
|
|
||||||
|
weights = auto_set_score_bias(weights)
|
||||||
if tokens is None and method is None:
|
if tokens is None and method is None:
|
||||||
return super().load_weights(weights)
|
return super().load_weights(weights)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user