mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 10:30:37 +08:00)
[CI/Build] Simplify model loading for HfRunner (#5251)
parent 27208be66e
commit 9ba093b4f4
@@ -1,14 +1,15 @@
 import contextlib
 import gc
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, TypeVar

 import pytest
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
-                          LlavaConfig, LlavaForConditionalGeneration)
+from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
+                          AutoProcessor, AutoTokenizer, BatchEncoding)

 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "float": torch.float,
 }

-AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
-
-_EMBEDDING_MODELS = [
-    "intfloat/e5-mistral-7b-instruct",
-]
+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)


 class HfRunner:

-    def wrap_device(self, input: any):
+    def wrap_device(self, input: _T) -> _T:
         if not is_cpu():
             return input.to("cuda")
         else:
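For context on the change above: a minimal sketch, outside this diff, of the constrained-TypeVar pattern that wrap_device now uses. The is_cpu stub below is an assumption standing in for vLLM's own helper; the types are the ones named in the hunk.

from typing import TypeVar

import torch
import torch.nn as nn
from transformers import BatchEncoding

# Constrained TypeVar: the annotation says wrap_device returns the same
# kind of object it was given (module, tensor, or tokenizer output).
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)


def is_cpu() -> bool:
    # Assumed stand-in for vLLM's is_cpu() helper referenced in the diff.
    return not torch.cuda.is_available()


def wrap_device(input: _T) -> _T:
    # nn.Module, torch.Tensor, and BatchEncoding all expose .to(device).
    if not is_cpu():
        return input.to("cuda")
    return input.to("cpu")


# A type checker now infers torch.Tensor here, not a union of all three types.
tensor_on_device = wrap_device(torch.zeros(2, 2))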
@@ -163,13 +160,16 @@ class HfRunner:
         self,
         model_name: str,
         dtype: str = "half",
+        *,
+        is_embedding_model: bool = False,
+        is_vision_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

         self.model_name = model_name

-        if model_name in _EMBEDDING_MODELS:
+        if is_embedding_model:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = self.wrap_device(
@@ -178,8 +178,13 @@ class HfRunner:
                     device="cpu",
                 ).to(dtype=torch_dtype))
         else:
+            if is_vision_model:
+                auto_cls = AutoModelForVision2Seq
+            else:
+                auto_cls = AutoModelForCausalLM
+
             self.model = self.wrap_device(
-                AutoModelForCausalLM.from_pretrained(
+                auto_cls.from_pretrained(
                     model_name,
                     torch_dtype=torch_dtype,
                     trust_remote_code=True,
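As a rough illustration of the selection logic added above (a sketch under assumptions: load_hf_model and the float16 dtype are placeholders, not names from this commit), the generation path now derives the Transformers auto class from a caller-supplied flag instead of registering LLaVA onto AutoModelForCausalLM:

import torch
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq


def load_hf_model(model_name: str, is_vision_model: bool = False):
    # Vision-language checkpoints load through AutoModelForVision2Seq;
    # plain decoder-only LMs keep using AutoModelForCausalLM.
    auto_cls = (AutoModelForVision2Seq
                if is_vision_model else AutoModelForCausalLM)
    return auto_cls.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )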
@@ -28,7 +28,7 @@ def test_models(
     model: str,
     dtype: str,
 ) -> None:
-    hf_model = hf_runner(model, dtype=dtype)
+    hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
     hf_outputs = hf_model.encode(example_prompts)
     del hf_model

@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
     """
     model_id, vision_language_config = model_and_config

-    hf_model = hf_runner(model_id, dtype=dtype)
+    hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
     hf_outputs = hf_model.generate_greedy(hf_image_prompts,
                                           max_tokens,
                                           images=hf_images)
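Taken together, the two test changes show the intended call pattern. Below is a sketch of how a test might exercise both new keyword-only flags; hf_runner is the conftest fixture changed above, the e5-mistral ID comes from the removed _EMBEDDING_MODELS list, and the LLaVA model ID is a placeholder assumption.

def test_hf_runner_flags(hf_runner):
    # Embedding model: HfRunner routes through sentence-transformers.
    embedding_model = hf_runner("intfloat/e5-mistral-7b-instruct",
                                dtype="half",
                                is_embedding_model=True)
    del embedding_model

    # Vision-language model: HfRunner loads via AutoModelForVision2Seq.
    vision_model = hf_runner("llava-hf/llava-1.5-7b-hf",  # placeholder model ID
                             dtype="half",
                             is_vision_model=True)
    del vision_model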