Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 13:44:58 +08:00)
[CI/Build] Simplify model loading for HfRunner (#5251)
parent 27208be66e
commit 9ba093b4f4
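
This change removes HfRunner's hard-coded special cases for picking a loader: the module-level _EMBEDDING_MODELS list and the AutoModelForCausalLM.register(LlavaConfig, ...) workaround are replaced by explicit is_embedding_model and is_vision_model keyword flags, and wrap_device gains a constrained TypeVar annotation so its return type follows its argument type.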
tests/conftest.py

@@ -1,14 +1,15 @@
 import contextlib
 import gc
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, TypeVar
 
 import pytest
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
-                          LlavaConfig, LlavaForConditionalGeneration)
+from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
+                          AutoProcessor, AutoTokenizer, BatchEncoding)
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "float": torch.float,
 }
 
-AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
-
-_EMBEDDING_MODELS = [
-    "intfloat/e5-mistral-7b-instruct",
-]
+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
 
 
 class HfRunner:
 
-    def wrap_device(self, input: any):
+    def wrap_device(self, input: _T) -> _T:
         if not is_cpu():
             return input.to("cuda")
         else:
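
The constrained TypeVar introduced above lets a type checker see that wrap_device returns the same concrete type it was given, where the old annotation (input: any, a typo for Any) said nothing. A minimal self-contained sketch of the pattern; the is_cpu stub and the CPU fallback are assumptions, since the diff truncates the else branch:

from typing import TypeVar

import torch
import torch.nn as nn
from transformers import BatchEncoding

# A constrained TypeVar: only these three types are admissible, and the
# `-> _T` return annotation means a Tensor in yields a Tensor out, a
# module in yields a module out, and so on.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)


def is_cpu() -> bool:
    # Stand-in for vLLM's platform check, not its actual implementation.
    return not torch.cuda.is_available()


def wrap_device(input: _T) -> _T:
    if not is_cpu():
        # nn.Module, torch.Tensor, and BatchEncoding all support .to(device).
        return input.to("cuda")
    # The else branch is cut off in the diff; leaving the input on the CPU
    # is one plausible completion for a CPU-only build.
    return input

With this in place, wrap_device(torch.ones(2)) type-checks as a torch.Tensor, and a call with an unsupported type such as a plain string is flagged.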
@@ -163,13 +160,16 @@ class HfRunner:
         self,
         model_name: str,
         dtype: str = "half",
+        *,
+        is_embedding_model: bool = False,
+        is_vision_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
         self.model_name = model_name
 
-        if model_name in _EMBEDDING_MODELS:
+        if is_embedding_model:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = self.wrap_device(
@@ -178,8 +178,13 @@ class HfRunner:
                     device="cpu",
                 ).to(dtype=torch_dtype))
         else:
+            if is_vision_model:
+                auto_cls = AutoModelForVision2Seq
+            else:
+                auto_cls = AutoModelForCausalLM
+
             self.model = self.wrap_device(
-                AutoModelForCausalLM.from_pretrained(
+                auto_cls.from_pretrained(
                     model_name,
                     torch_dtype=torch_dtype,
                     trust_remote_code=True,
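
With the flags in place, loader selection in __init__ reduces to plain branching rather than membership in a global model list. A standalone sketch of the same dispatch, using a hypothetical free function load_hf_model rather than vLLM's actual runner class:

import torch
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq


def load_hf_model(model_name: str,
                  torch_dtype: torch.dtype = torch.half,
                  *,
                  is_embedding_model: bool = False,
                  is_vision_model: bool = False):
    """Hypothetical helper mirroring the dispatch in HfRunner.__init__."""
    if is_embedding_model:
        # Embedding models go through sentence-transformers instead of a
        # transformers Auto* class (lazy import, as in the diff).
        from sentence_transformers import SentenceTransformer
        return SentenceTransformer(model_name,
                                   device="cpu").to(dtype=torch_dtype)

    # Vision-language checkpoints need the Vision2Seq auto-mapping;
    # everything else loads as a plain causal LM.
    auto_cls = AutoModelForVision2Seq if is_vision_model else AutoModelForCausalLM
    return auto_cls.from_pretrained(model_name,
                                    torch_dtype=torch_dtype,
                                    trust_remote_code=True)

Making the flags keyword-only (after the bare *) keeps call sites self-describing and avoids growing a new hard-coded model list every time a test adds a model.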
tests/models/test_embedding.py

@@ -28,7 +28,7 @@ def test_models(
     model: str,
     dtype: str,
 ) -> None:
-    hf_model = hf_runner(model, dtype=dtype)
+    hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
     hf_outputs = hf_model.encode(example_prompts)
     del hf_model
 
tests/models/test_llava.py

@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
     """
     model_id, vision_language_config = model_and_config
 
-    hf_model = hf_runner(model_id, dtype=dtype)
+    hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
     hf_outputs = hf_model.generate_greedy(hf_image_prompts,
                                           max_tokens,
                                           images=hf_images)
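
At the call sites the intent is now explicit. Illustrative invocations in the style of the updated tests, where hf_runner is the pytest fixture that constructs HfRunner; the e5 model ID comes from the removed _EMBEDDING_MODELS list, while the llava ID is only an example:

# Embedding test path: routed through sentence-transformers.
hf_model = hf_runner("intfloat/e5-mistral-7b-instruct",
                     dtype="half",
                     is_embedding_model=True)

# Vision-language test path: loaded via AutoModelForVision2Seq.
hf_model = hf_runner("llava-hf/llava-1.5-7b-hf",
                     dtype="half",
                     is_vision_model=True)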