Transformers backend already supports V1 (#15463)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit e42389f9d7
parent ff38f0a32c
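Not part of the commit, but for context: a minimal sketch of requesting the Transformers implementation from the offline API once the V1 fallback removed below is gone. It assumes the `model_impl` keyword is forwarded from `LLM` to `EngineArgs`; the model name is taken from the test parametrization in this diff.

from vllm import LLM, SamplingParams

# Ask vLLM to run the model through the Transformers backend rather than a
# native vLLM implementation ("auto" would pick one automatically).
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)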
@@ -3,8 +3,6 @@
 
 Run `pytest tests/models/test_transformers.py`.
 """
-from contextlib import nullcontext
-
 import pytest
 
 from ..conftest import HfRunner, VllmRunner
@@ -42,7 +40,6 @@ def check_implementation(
     "model,model_impl",
     [
         ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
-        ("openai-community/gpt2", "transformers"),
         ("ArthurZ/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
     ])  # trust_remote_code=True by default
 def test_models(
@@ -52,15 +49,6 @@ def test_models(
     model: str,
     model_impl: str,
 ) -> None:
-
-    maybe_raises = nullcontext()
-    if model == "openai-community/gpt2" and model_impl == "transformers":
-        # Model is not backend compatible
-        maybe_raises = pytest.raises(
-            ValueError,
-            match="The Transformers implementation.*not compatible with vLLM")
-
-    with maybe_raises:
-        check_implementation(hf_runner,
-                             vllm_runner,
-                             example_prompts,
+    check_implementation(hf_runner,
+                         vllm_runner,
+                         example_prompts,
@@ -1613,14 +1613,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # No TransformersModel support so far.
-        if (model_config.model_impl == ModelImpl.TRANSFORMERS
-                or model_config.model_impl == "transformers"):
-            _raise_or_fallback(
-                feature_name=f"model_impl={model_config.model_impl}",
-                recommend_to_remove=False)
-            return False
-
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
                 != EngineArgs.max_num_partial_prefills
@@ -24,6 +24,7 @@ from transformers import AutoModel, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, VllmConfig)
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -109,6 +110,7 @@ def replace_linear_class(
     )
 
 
+@support_torch_compile
 class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"
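For readers unfamiliar with the decorator added above: a rough, self-contained illustration of what opting a module's forward into torch.compile looks like. This toy decorator is not vllm.compilation.decorators.support_torch_compile (the real one is driven by VllmConfig/CompilationConfig and handles dynamic shapes and CUDA graph capture); it only sketches the class-decorator pattern.

import torch
import torch.nn as nn

def compile_forward(cls):
    """Toy stand-in for a support_torch_compile-style class decorator."""
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # Shadow the class-level forward with a compiled bound forward;
        # nn.Module.__call__ looks up self.forward, so the instance
        # attribute is used on subsequent calls.
        self.forward = torch.compile(self.forward, dynamic=True)

    cls.__init__ = __init__
    return cls

@compile_forward
class TinyMLP(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))

print(TinyMLP()(torch.randn(2, 64)).shape)  # torch.Size([2, 64])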