Enable headless models for pooling in the Transformers backend (#21767)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent ac45c44d98
commit 38c8bce8b6
@@ -525,6 +525,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
 }

 _TRANSFORMERS_BACKEND_MODELS = {
+    "TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
     "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
     "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
 }
@@ -34,8 +34,7 @@ def check_implementation(

     with runner_test(model, **kwargs_test, **kwargs) as model_test:
         model_config = model_test.llm.llm_engine.model_config
-        assert model_config.architecture == (
-            model_config._get_transformers_backend_cls())
+        assert model_config.using_transformers_backend()

         outputs_test = model_test.generate_greedy_logprobs(*args)

@@ -135,8 +134,7 @@ def test_quantization(
                      enforce_eager=True,
                      **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
         model_config = vllm_model.llm.llm_engine.model_config
-        assert model_config.architecture == (
-            model_config._get_transformers_backend_cls())
+        assert model_config.using_transformers_backend()

         transformers_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
@@ -149,6 +147,25 @@ def test_quantization(
     )


+@pytest.mark.parametrize(
+    "model",
+    [
+        # Layers live in `layers`
+        "Qwen/Qwen3-Embedding-0.6B",
+        # Layers live in `model.layers`
+        "meta-llama/Llama-3.2-1B-Instruct"
+    ],
+)
+def test_embed_loading(vllm_runner, model):
+    with vllm_runner(model,
+                     max_model_len=1024,
+                     enforce_eager=True,
+                     runner="pooling",
+                     model_impl="transformers") as model_test:
+        model_config = model_test.llm.llm_engine.model_config
+        assert model_config.using_transformers_backend()
+
+
 @pytest.mark.parametrize(
     "model",
     ["jason9693/Qwen2.5-1.5B-apeach"],
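As a usage sketch (not part of the commit), loading one of these embedding checkpoints through the new pooling path could look like the snippet below. The `runner` and `model_impl` keyword arguments are the same ones the test above forwards to the engine, and `LLM.embed` is assumed to be available in this version of vLLM.

# Minimal sketch of what the new test_embed_loading exercises: a headless
# embedding checkpoint served by the Transformers backend's pooling path.
# Assumes an LLM constructor that accepts the `runner` and `model_impl`
# keywords used by the test above.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",
    runner="pooling",            # pooling runner -> TransformersModel is selected
    model_impl="transformers",   # force the Transformers backend
    max_model_len=1024,
    enforce_eager=True,
)

# `embed` returns one pooled embedding per prompt.
outputs = llm.embed(["vLLM can now pool with the Transformers backend."])
print(len(outputs[0].outputs.embedding))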
@@ -169,8 +186,7 @@ def test_classify(
                      dtype=dtype,
                      model_impl="transformers") as vllm_model:
         model_config = vllm_model.llm.llm_engine.model_config
-        assert model_config.architecture == (
-            model_config._get_transformers_backend_cls())
+        assert model_config.using_transformers_backend()

         vllm_outputs = vllm_model.classify(example_prompts)

@@ -812,13 +812,18 @@ class ModelConfig:
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
+        if getattr(self, "runner_type", self.runner) == "pooling":
+            return "TransformersModel"
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
             return "TransformersForMultimodalLM"
-        else:
-            return "TransformersForCausalLM"
+        return "TransformersForCausalLM"
+
+    def using_transformers_backend(self) -> bool:
+        """Check if the model is using the Transformers backend class."""
+        return self.architecture == self._get_transformers_backend_cls()

     @property
     def registry(self):
         return me_models.ModelRegistry
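The decision order introduced here can be restated as a small standalone function. This is an illustration of the selection logic only, not the actual `ModelConfig` method: pooling runners take priority, then composite (multimodal) configs, then the causal-LM default.

# Illustrative restatement of the backend-class selection order (assumed
# helper, not vLLM code).
def pick_transformers_backend_cls(runner_type: str,
                                  is_composite_config: bool) -> str:
    if runner_type == "pooling":
        return "TransformersModel"
    if is_composite_config:
        return "TransformersForMultimodalLM"
    return "TransformersForCausalLM"


assert pick_transformers_backend_cls("pooling", False) == "TransformersModel"
assert pick_transformers_backend_cls("generate", True) == "TransformersForMultimodalLM"
assert pick_transformers_backend_cls("generate", False) == "TransformersForCausalLM"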
@@ -270,8 +270,9 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
 }

 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
+    "TransformersModel": ("transformers", "TransformersModel"),
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
+    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
 }
 # yapf: enable

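For orientation, each registry entry is a (module, class) pair. A simplified resolver might look like the sketch below, assuming the modules live under `vllm.model_executor.models`; vLLM's real `ModelRegistry` does considerably more (lazy imports, subprocess checks), so this is only an illustration.

# Hypothetical resolver for (module, class) registry entries like the ones
# above; not vLLM's actual ModelRegistry.
import importlib

_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersModel": ("transformers", "TransformersModel"),
    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
    "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),
}

def resolve(architecture: str, package: str = "vllm.model_executor.models"):
    module_name, class_name = _TRANSFORMERS_BACKEND_MODELS[architecture]
    module = importlib.import_module(f"{package}.{module_name}")
    return getattr(module, class_name)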
@@ -651,6 +651,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


+@support_torch_compile
+class TransformersModel(TransformersBase):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # Add `model.` prefix for base model checkpoints
+            "": "model.",
+            # Remove `model.` from places it should not be
+            "model.model.": "model.",
+            "model.score": "score",
+        })
+
+
 @support_torch_compile
 class TransformersForCausalLM(TransformersBase):

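A toy approximation of what this prefix table is meant to achieve for checkpoint weight names is shown below. It is not vLLM's actual `WeightsMapper` implementation; it just applies the three substitutions in sequence so that headless base-model checkpoints land under the `model.` prefix while doubled or misplaced prefixes are corrected.

# Toy stand-in for the prefix remapping above (assumed helper, not vLLM code).
def remap(name: str) -> str:
    # Add `model.` prefix for base model checkpoints.
    name = "model." + name
    # Remove `model.` from places it should not be.
    if name.startswith("model.model."):
        name = name.replace("model.model.", "model.", 1)
    if name.startswith("model.score"):
        name = name.replace("model.score", "score", 1)
    return name

assert remap("layers.0.self_attn.q_proj.weight") == "model.layers.0.self_attn.q_proj.weight"
assert remap("model.layers.0.mlp.up_proj.weight") == "model.layers.0.mlp.up_proj.weight"
assert remap("score.weight") == "score.weight"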