mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
Enable headless models for pooling in the Transformers backend (#21767)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
ac45c44d98
commit
38c8bce8b6
@ -525,6 +525,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
}
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
"TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
|
||||
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
||||
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
|
||||
}
|
||||
|
||||
@ -34,8 +34,7 @@ def check_implementation(
|
||||
|
||||
with runner_test(model, **kwargs_test, **kwargs) as model_test:
|
||||
model_config = model_test.llm.llm_engine.model_config
|
||||
assert model_config.architecture == (
|
||||
model_config._get_transformers_backend_cls())
|
||||
assert model_config.using_transformers_backend()
|
||||
|
||||
outputs_test = model_test.generate_greedy_logprobs(*args)
|
||||
|
||||
@ -135,8 +134,7 @@ def test_quantization(
|
||||
enforce_eager=True,
|
||||
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
assert model_config.architecture == (
|
||||
model_config._get_transformers_backend_cls())
|
||||
assert model_config.using_transformers_backend()
|
||||
|
||||
transformers_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
|
||||
@ -149,6 +147,25 @@ def test_quantization(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
# Layers live in `layers`
|
||||
"Qwen/Qwen3-Embedding-0.6B",
|
||||
# Layers live in `model.layers`
|
||||
"meta-llama/Llama-3.2-1B-Instruct"
|
||||
],
|
||||
)
|
||||
def test_embed_loading(vllm_runner, model):
|
||||
with vllm_runner(model,
|
||||
max_model_len=1024,
|
||||
enforce_eager=True,
|
||||
runner="pooling",
|
||||
model_impl="transformers") as model_test:
|
||||
model_config = model_test.llm.llm_engine.model_config
|
||||
assert model_config.using_transformers_backend()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["jason9693/Qwen2.5-1.5B-apeach"],
|
||||
@ -169,8 +186,7 @@ def test_classify(
|
||||
dtype=dtype,
|
||||
model_impl="transformers") as vllm_model:
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
assert model_config.architecture == (
|
||||
model_config._get_transformers_backend_cls())
|
||||
assert model_config.using_transformers_backend()
|
||||
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
|
||||
@ -812,12 +812,17 @@ class ModelConfig:
|
||||
def _get_transformers_backend_cls(self) -> str:
|
||||
"""Determine which Transformers backend class will be used if
|
||||
`model_impl` is set to `transformers` or `auto`."""
|
||||
if getattr(self, "runner_type", self.runner) == "pooling":
|
||||
return "TransformersModel"
|
||||
if self.hf_config != self.hf_text_config:
|
||||
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
|
||||
# probably a composite config, i.e. multimodal
|
||||
return "TransformersForMultimodalLM"
|
||||
else:
|
||||
return "TransformersForCausalLM"
|
||||
return "TransformersForCausalLM"
|
||||
|
||||
def using_transformers_backend(self) -> bool:
|
||||
"""Check if the model is using the Transformers backend class."""
|
||||
return self.architecture == self._get_transformers_backend_cls()
|
||||
|
||||
@property
|
||||
def registry(self):
|
||||
|
||||
@ -270,8 +270,9 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
|
||||
}
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
||||
"TransformersModel": ("transformers", "TransformersModel"),
|
||||
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@ -651,6 +651,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TransformersModel(TransformersBase):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# Add `model.` prefix for base model checkpoints
|
||||
"": "model.",
|
||||
# Remove `model.` from places it should not be
|
||||
"model.model.": "model.",
|
||||
"model.score": "score",
|
||||
})
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TransformersForCausalLM(TransformersBase):
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user