[CI/Build] Simplify model loading for HfRunner (#5251)

This commit is contained in:
Cyrus Leung 2024-06-05 01:09:19 +08:00 committed by GitHub
parent 27208be66e
commit 9ba093b4f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 13 deletions

View File

@@ -1,14 +1,15 @@
import contextlib
import gc
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TypeVar
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
LlavaConfig, LlavaForConditionalGeneration)
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoProcessor, AutoTokenizer, BatchEncoding)
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"float": torch.float,
}
AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
_EMBEDDING_MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
class HfRunner:
def wrap_device(self, input: any):
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
return input.to("cuda")
else:
@@ -163,13 +160,16 @@ class HfRunner:
self,
model_name: str,
dtype: str = "half",
*,
is_embedding_model: bool = False,
is_vision_model: bool = False,
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
if model_name in _EMBEDDING_MODELS:
if is_embedding_model:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
@@ -178,8 +178,13 @@ class HfRunner:
device="cpu",
).to(dtype=torch_dtype))
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
else:
auto_cls = AutoModelForCausalLM
self.model = self.wrap_device(
AutoModelForCausalLM.from_pretrained(
auto_cls.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,

View File

@@ -28,7 +28,7 @@ def test_models(
model: str,
dtype: str,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
hf_outputs = hf_model.encode(example_prompts)
del hf_model

View File

@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
"""
model_id, vision_language_config = model_and_config
hf_model = hf_runner(model_id, dtype=dtype)
hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
hf_outputs = hf_model.generate_greedy(hf_image_prompts,
max_tokens,
images=hf_images)