Enable headless models for pooling in the Transformers backend (#21767)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-08-01 18:31:29 +01:00 committed by GitHub
parent ac45c44d98
commit 38c8bce8b6
5 changed files with 44 additions and 9 deletions

View File

@@ -525,6 +525,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
}
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
}

View File

@@ -34,8 +34,7 @@ def check_implementation(
with runner_test(model, **kwargs_test, **kwargs) as model_test:
model_config = model_test.llm.llm_engine.model_config
assert model_config.architecture == (
model_config._get_transformers_backend_cls())
assert model_config.using_transformers_backend()
outputs_test = model_test.generate_greedy_logprobs(*args)
@@ -135,8 +134,7 @@ def test_quantization(
enforce_eager=True,
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.architecture == (
model_config._get_transformers_backend_cls())
assert model_config.using_transformers_backend()
transformers_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
@@ -149,6 +147,25 @@ def test_quantization(
)
@pytest.mark.parametrize(
"model",
[
# Layers live in `layers`
"Qwen/Qwen3-Embedding-0.6B",
# Layers live in `model.layers`
"meta-llama/Llama-3.2-1B-Instruct"
],
)
def test_embed_loading(vllm_runner, model):
with vllm_runner(model,
max_model_len=1024,
enforce_eager=True,
runner="pooling",
model_impl="transformers") as model_test:
model_config = model_test.llm.llm_engine.model_config
assert model_config.using_transformers_backend()
@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
@@ -169,8 +186,7 @@ def test_classify(
dtype=dtype,
model_impl="transformers") as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.architecture == (
model_config._get_transformers_backend_cls())
assert model_config.using_transformers_backend()
vllm_outputs = vllm_model.classify(example_prompts)

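For reference, the sketch below shows the kind of usage the new `test_embed_loading` test exercises: loading an embedding model headlessly through the Transformers backend with the pooling runner. It is illustrative only; the model name and keyword arguments mirror the test above.

```python
# Minimal sketch, mirroring test_embed_loading above (illustrative only).
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",
    runner="pooling",           # pooling runner -> headless TransformersModel
    model_impl="transformers",  # force the Transformers backend
    max_model_len=1024,
    enforce_eager=True,
)

# Each pooling output carries the pooled embedding vector for its prompt.
(output,) = llm.embed(["vLLM now pools with the Transformers backend."])
print(len(output.outputs.embedding))
```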
View File

@@ -812,12 +812,17 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
if getattr(self, "runner_type", self.runner) == "pooling":
return "TransformersModel"
if self.hf_config != self.hf_text_config:
# 'hf_text_config' differing from 'hf_config' usually indicates a
# composite config, i.e. multimodal
return "TransformersForMultimodalLM"
else:
return "TransformersForCausalLM"
return "TransformersForCausalLM"
def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""
return self.architecture == self._get_transformers_backend_cls()
@property
def registry(self):

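The dispatch order above matters: the pooling check runs first, so any model served with the pooling runner gets the headless `TransformersModel` class regardless of whether its config is composite. A simplified standalone restatement (not the actual `ModelConfig` code):

```python
# Simplified restatement of _get_transformers_backend_cls (illustrative only).
def select_transformers_backend_cls(runner_type: str,
                                    is_composite_config: bool) -> str:
    if runner_type == "pooling":
        # Headless model (no LM head) for embedding/pooling workloads.
        return "TransformersModel"
    if is_composite_config:
        # hf_config != hf_text_config usually means a multimodal composite config.
        return "TransformersForMultimodalLM"
    return "TransformersForCausalLM"

assert select_transformers_backend_cls("pooling", False) == "TransformersModel"
assert select_transformers_backend_cls("generate", True) == "TransformersForMultimodalLM"
assert select_transformers_backend_cls("generate", False) == "TransformersForCausalLM"
```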
View File

@@ -270,8 +270,9 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
}
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersModel": ("transformers", "TransformersModel"),
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
}
# yapf: enable

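With the entry above in place, the headless architecture should resolve through the model registry. A quick sanity check (sketch; assumes `ModelRegistry.get_supported_archs()` is the public accessor):

```python
# Sketch: confirm the headless architecture is registered.
from vllm.model_executor.models.registry import ModelRegistry

assert "TransformersModel" in ModelRegistry.get_supported_archs()
```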
View File

@@ -651,6 +651,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@support_torch_compile
class TransformersModel(TransformersBase):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` from places it should not be
"model.model.": "model.",
"model.score": "score",
})
@support_torch_compile
class TransformersForCausalLM(TransformersBase):
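To show what the new mapper does to checkpoint key names, here is a rough standalone sketch. It assumes the prefix mappings are applied in dictionary order, each rewriting one matching prefix; the real behaviour is implemented by vLLM's `WeightsMapper`.

```python
# Illustrative-only re-implementation of the prefix remapping above.
orig_to_new_prefix = {
    "": "model.",              # add `model.` prefix for base model checkpoints
    "model.model.": "model.",  # undo the doubled prefix on `model.*` checkpoints
    "model.score": "score",    # keep the score head at the top level
}

def remap(name: str) -> str:
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            name = new + name[len(old):]
    return name

print(remap("layers.0.mlp.up_proj.weight"))        # model.layers.0.mlp.up_proj.weight
print(remap("model.layers.0.mlp.up_proj.weight"))  # model.layers.0.mlp.up_proj.weight
print(remap("score.weight"))                       # score.weight
```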