[CI/Build] Further decouple HuggingFace implementation from ours during tests (#4166)
This commit is contained in: parent 65bf2ac165, commit e9cdd2b1e2
@@ -1,19 +1,21 @@
 import contextlib
 import gc
 import os
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 import torch
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoProcessor,
-                          LlavaForConditionalGeneration)
+from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
+                          LlavaConfig, LlavaForConditionalGeneration)
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
+from vllm.logger import init_logger
 from vllm.sequence import MultiModalData
-from vllm.transformers_utils.tokenizer import get_tokenizer
 
+logger = init_logger(__name__)
+
 _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -129,9 +131,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "float": torch.float,
 }
 
-_VISION_LANGUAGE_MODELS = {
-    "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
-}
+AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
 
 _EMBEDDING_MODELS = [
     "intfloat/e5-mistral-7b-instruct",
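The AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration) call above is what makes the hand-maintained _VISION_LANGUAGE_MODELS table unnecessary: once LlavaConfig is registered against the auto-class, the generic from_pretrained path resolves LLaVA checkpoints by their config type. A minimal sketch of that behaviour, assuming a transformers release in which LLaVA is not already present in the causal-LM auto-mapping; the dtype choice is illustrative:

# Sketch only, not part of the diff: shows what the registration enables.
# Assumes a transformers version where LLaVA is absent from the causal-LM
# auto-mapping, so the register() call is what wires it in.
import torch
from transformers import (AutoModelForCausalLM, LlavaConfig,
                          LlavaForConditionalGeneration)

AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)

# The auto-class now resolves the LLaVA checkpoint from its config type,
# so the test harness needs no per-model lookup table.
model = AutoModelForCausalLM.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.half,
)
assert isinstance(model, LlavaForConditionalGeneration)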
@@ -143,23 +143,14 @@ class HfRunner:
     def __init__(
         self,
         model_name: str,
-        tokenizer_name: Optional[str] = None,
         dtype: str = "half",
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+
         self.model_name = model_name
-        if model_name in _VISION_LANGUAGE_MODELS:
-            self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-            ).cuda()
-            self.processor = AutoProcessor.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-            )
-        elif model_name in _EMBEDDING_MODELS:
+
+        if model_name in _EMBEDDING_MODELS:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = SentenceTransformer(
@@ -172,10 +163,24 @@
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
             ).cuda()
-            self.processor = None
-        if tokenizer_name is None:
-            tokenizer_name = model_name
-        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+
+        try:
+            self.processor = AutoProcessor.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            )
+        except Exception:
+            logger.warning(
+                "Unable to auto-load processor from HuggingFace for "
+                "model %s. Using tokenizer instead.", model_name)
+            self.processor = self.tokenizer
 
     def generate(
         self,
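The try/except above keeps a single downstream call path: models that ship a processor get one from AutoProcessor, and text-only models fall back to the tokenizer, which accepts the same keyword-style call for text input. A small illustrative check of that interchangeability; the model name is only an example:

# Illustration only, not part of the diff: a tokenizer answers the same
# text-only call a processor would receive, so self.processor can point at
# either object without changing generate().
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # example model

inputs = tokenizer(text="Hello, my name is", return_tensors="pt")
# 'inputs' is a BatchEncoding holding input_ids / attention_mask tensors,
# the same kind of mapping the generate() path below moves onto the GPU.
print(sorted(inputs.keys()))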
@@ -187,19 +192,19 @@
         if images:
             assert len(prompts) == len(images)
         for i, prompt in enumerate(prompts):
-            if self.model_name not in _VISION_LANGUAGE_MODELS:
-                input_ids = self.tokenizer(prompt,
-                                           return_tensors="pt").input_ids
-                inputs = {"input_ids": input_ids.cuda()}
-            else:
-                image = images[i] if images else None
-                inputs = self.processor(text=prompt,
-                                        images=image,
-                                        return_tensors="pt")
-                inputs = {
-                    key: value.cuda() if value is not None else None
-                    for key, value in inputs.items()
-                }
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+            inputs = {
+                key: value.cuda() if value is not None else None
+                for key, value in inputs.items()
+            }
+
             output_ids = self.model.generate(
                 **inputs,
                 use_cache=True,
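For clarity, the kwargs-building pattern from the new generate() path, restated outside the runner: the kwargs are assembled first and the processor is called once, so a single code path serves both text-only prompts and prompts with images, whether self.processor is a real multi-modal processor or the tokenizer fallback. In the sketch below, build_inputs is a hypothetical helper, not something defined in this diff:

# Standalone restatement of the kwargs-building pattern above; build_inputs
# is a hypothetical helper for illustration, not part of the test harness.
from typing import Any, Dict, Optional

from PIL import Image


def build_inputs(processor, prompt: str, image: Optional[Image.Image] = None):
    processor_kwargs: Dict[str, Any] = {
        "text": prompt,
        "return_tensors": "pt",
    }
    if image is not None:
        processor_kwargs["images"] = image

    inputs = processor(**processor_kwargs)
    # Mirror the runner: move every returned tensor onto the GPU.
    return {
        key: value.cuda() if value is not None else None
        for key, value in inputs.items()
    }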