diff --git a/tests/conftest.py b/tests/conftest.py
index 999ace2c3c699..c1a44a606e1bf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,19 +1,21 @@
 import contextlib
 import gc
 import os
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 import torch
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoProcessor,
-                          LlavaForConditionalGeneration)
+from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
+                          LlavaConfig, LlavaForConditionalGeneration)
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
+from vllm.logger import init_logger
 from vllm.sequence import MultiModalData
-from vllm.transformers_utils.tokenizer import get_tokenizer
+
+logger = init_logger(__name__)
 
 _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -129,9 +131,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "float": torch.float,
 }
 
-_VISION_LANGUAGE_MODELS = {
-    "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
-}
+AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
 
 _EMBEDDING_MODELS = [
     "intfloat/e5-mistral-7b-instruct",
@@ -143,23 +143,14 @@ class HfRunner:
     def __init__(
         self,
         model_name: str,
-        tokenizer_name: Optional[str] = None,
         dtype: str = "half",
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+
         self.model_name = model_name
-        if model_name in _VISION_LANGUAGE_MODELS:
-            self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-            ).cuda()
-            self.processor = AutoProcessor.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-            )
-        elif model_name in _EMBEDDING_MODELS:
+
+        if model_name in _EMBEDDING_MODELS:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = SentenceTransformer(
@@ -172,10 +163,24 @@ class HfRunner:
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
             ).cuda()
-            self.processor = None
-        if tokenizer_name is None:
-            tokenizer_name = model_name
-        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        try:
+            self.processor = AutoProcessor.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            )
+        except Exception:
+            logger.warning(
+                "Unable to auto-load processor from HuggingFace for "
+                "model %s. Using tokenizer instead.", model_name)
+            self.processor = self.tokenizer
 
     def generate(
         self,
@@ -187,19 +192,19 @@ class HfRunner:
         if images:
             assert len(prompts) == len(images)
         for i, prompt in enumerate(prompts):
-            if self.model_name not in _VISION_LANGUAGE_MODELS:
-                input_ids = self.tokenizer(prompt,
-                                           return_tensors="pt").input_ids
-                inputs = {"input_ids": input_ids.cuda()}
-            else:
-                image = images[i] if images else None
-                inputs = self.processor(text=prompt,
-                                        images=image,
-                                        return_tensors="pt")
-                inputs = {
-                    key: value.cuda() if value is not None else None
-                    for key, value in inputs.items()
-                }
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+            inputs = {
+                key: value.cuda() if value is not None else None
+                for key, value in inputs.items()
+            }
+
             output_ids = self.model.generate(
                 **inputs,
                 use_cache=True,
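
For reference, a minimal sketch (not part of the patch) of the processor-or-tokenizer fallback that the revised HfRunner.__init__ relies on, in isolation from the vLLM test harness. The model name "facebook/opt-125m" is only an illustrative stand-in for a text-only checkpoint; any model id would do.

# Illustrative sketch only: mirrors the try/except fallback introduced above.
from transformers import AutoProcessor, AutoTokenizer

model_name = "facebook/opt-125m"  # stand-in model, chosen for illustration

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
try:
    # Multimodal models (e.g. LLaVA) ship a processor that handles text and images.
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
except Exception:
    # Text-only models may not provide one; fall back to the tokenizer so the
    # same processor(**kwargs) call path works for every model.
    processor = tokenizer

inputs = processor(text="Hello, world!", return_tensors="pt")
print(inputs["input_ids"].shape)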