diff --git a/tests/conftest.py b/tests/conftest.py
index 764374a779d9e..55efc56ec3d02 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,14 +1,15 @@
 import contextlib
 import gc
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, TypeVar
 
 import pytest
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
-                          LlavaConfig, LlavaForConditionalGeneration)
+from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
+                          AutoProcessor, AutoTokenizer, BatchEncoding)
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "float": torch.float,
 }
 
-AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
-
-_EMBEDDING_MODELS = [
-    "intfloat/e5-mistral-7b-instruct",
-]
+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
 
 
 class HfRunner:
 
-    def wrap_device(self, input: any):
+    def wrap_device(self, input: _T) -> _T:
         if not is_cpu():
             return input.to("cuda")
         else:
@@ -163,13 +160,16 @@ class HfRunner:
         self,
         model_name: str,
         dtype: str = "half",
+        *,
+        is_embedding_model: bool = False,
+        is_vision_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
         self.model_name = model_name
 
-        if model_name in _EMBEDDING_MODELS:
+        if is_embedding_model:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = self.wrap_device(
@@ -178,8 +178,13 @@
                     device="cpu",
                 ).to(dtype=torch_dtype))
         else:
+            if is_vision_model:
+                auto_cls = AutoModelForVision2Seq
+            else:
+                auto_cls = AutoModelForCausalLM
+
             self.model = self.wrap_device(
-                AutoModelForCausalLM.from_pretrained(
+                auto_cls.from_pretrained(
                     model_name,
                     torch_dtype=torch_dtype,
                     trust_remote_code=True,
diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py
index 59bf054913f7c..668ed3a520a36 100644
--- a/tests/models/test_embedding.py
+++ b/tests/models/test_embedding.py
@@ -28,7 +28,7 @@ def test_models(
     model: str,
     dtype: str,
 ) -> None:
-    hf_model = hf_runner(model, dtype=dtype)
+    hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
     hf_outputs = hf_model.encode(example_prompts)
 
     del hf_model
diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py
index cc0685ca9c5eb..839a9f78d1bb8 100644
--- a/tests/models/test_llava.py
+++ b/tests/models/test_llava.py
@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
     """
     model_id, vision_language_config = model_and_config
 
-    hf_model = hf_runner(model_id, dtype=dtype)
+    hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
     hf_outputs = hf_model.generate_greedy(hf_image_prompts, max_tokens,
                                           images=hf_images)
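
Note (not part of the diff): with the hardcoded `AutoModelForCausalLM.register(...)` call and `_EMBEDDING_MODELS` list removed, test call sites now state the model family through the new keyword-only flags. A minimal sketch of the resulting usage, assuming `hf_runner` is the pytest fixture (defined elsewhere in tests/conftest.py) that constructs the `HfRunner` patched above; the LLaVA and OPT model IDs below are illustrative placeholders, not names taken from this patch:

    # Embedding model: routed to sentence_transformers.SentenceTransformer
    # because is_embedding_model=True is passed explicitly.
    hf_model = hf_runner("intfloat/e5-mistral-7b-instruct", dtype="half",
                         is_embedding_model=True)

    # Vision-language model: loaded with AutoModelForVision2Seq instead of
    # AutoModelForCausalLM. (Illustrative model ID.)
    hf_model = hf_runner("llava-hf/llava-1.5-7b-hf", dtype="half",
                         is_vision_model=True)

    # Default: both flags are False, so plain causal LMs still go through
    # AutoModelForCausalLM and existing call sites need no changes.
    hf_model = hf_runner("facebook/opt-125m", dtype="half")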