diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index 243cb92ae256..c45fc7e649ec 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -3,8 +3,6 @@
 
 Run `pytest tests/models/test_transformers.py`.
 """
-from contextlib import nullcontext
-
 import pytest
 
 from ..conftest import HfRunner, VllmRunner
@@ -42,7 +40,6 @@ def check_implementation(
     "model,model_impl",
     [
         ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
-        ("openai-community/gpt2", "transformers"),
         ("ArthurZ/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
     ])  # trust_remote_code=True by default
 def test_models(
@@ -52,20 +49,11 @@ def test_models(
     model: str,
     model_impl: str,
 ) -> None:
-
-    maybe_raises = nullcontext()
-    if model == "openai-community/gpt2" and model_impl == "transformers":
-        # Model is not backend compatible
-        maybe_raises = pytest.raises(
-            ValueError,
-            match="The Transformers implementation.*not compatible with vLLM")
-
-    with maybe_raises:
-        check_implementation(hf_runner,
-                             vllm_runner,
-                             example_prompts,
-                             model,
-                             model_impl=model_impl)
+    check_implementation(hf_runner,
+                         vllm_runner,
+                         example_prompts,
+                         model,
+                         model_impl=model_impl)
 
 
 @multi_gpu_test(num_gpus=2)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 65a1676c0637..75ac326aaa3d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1613,14 +1613,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # No TransformersModel support so far.
-        if (model_config.model_impl == ModelImpl.TRANSFORMERS
-                or model_config.model_impl == "transformers"):
-            _raise_or_fallback(
-                feature_name=f"model_impl={model_config.model_impl}",
-                recommend_to_remove=False)
-            return False
-
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
                 != EngineArgs.max_num_partial_prefills
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index fe6a9d7a4aa4..56ec00dcf222 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -24,6 +24,7 @@
 from transformers import AutoModel, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, VllmConfig)
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -109,6 +110,7 @@ def replace_linear_class(
     )
 
 
+@support_torch_compile
 class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"