[CI/Build] Simplify model loading for HfRunner (#5251)

Cyrus Leung 2024-06-05 01:09:19 +08:00 committed by GitHub
parent 27208be66e
commit 9ba093b4f4
3 changed files with 18 additions and 13 deletions

View File

@@ -1,14 +1,15 @@
 import contextlib
 import gc
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, TypeVar
 
 import pytest
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
-                          LlavaConfig, LlavaForConditionalGeneration)
+from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
+                          AutoProcessor, AutoTokenizer, BatchEncoding)
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "float": torch.float,
 }
 
-AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
-
-_EMBEDDING_MODELS = [
-    "intfloat/e5-mistral-7b-instruct",
-]
+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
 
 
 class HfRunner:
 
-    def wrap_device(self, input: any):
+    def wrap_device(self, input: _T) -> _T:
         if not is_cpu():
             return input.to("cuda")
         else:
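
Constraining _T to nn.Module, torch.Tensor, and BatchEncoding documents exactly which inputs wrap_device accepts and lets type checkers see that it returns the same type it was given, instead of the untyped "any" annotation it replaces. A minimal standalone sketch of the pattern; the is_cpu() stub below is an assumption standing in for vLLM's utility and is not part of the diff:

from typing import TypeVar

import torch
import torch.nn as nn
from transformers import BatchEncoding

# Same value-restricted TypeVar as in the diff: wrap_device returns
# exactly the type it was given.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)


def is_cpu() -> bool:
    # Stand-in for vllm.utils.is_cpu(); here we simply treat "no CUDA
    # device visible" as running on CPU.
    return not torch.cuda.is_available()


def wrap_device(input: _T) -> _T:
    # nn.Module, torch.Tensor, and BatchEncoding all implement .to(device),
    # so a single helper can move any of them onto the accelerator.
    if not is_cpu():
        return input.to("cuda")
    return input  # CPU case left untouched in this sketch


tensor_on_device = wrap_device(torch.zeros(2, 8))   # typed as torch.Tensor
module_on_device = wrap_device(nn.Linear(8, 8))     # typed as nn.Module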
@@ -163,13 +160,16 @@ class HfRunner:
         self,
         model_name: str,
         dtype: str = "half",
+        *,
+        is_embedding_model: bool = False,
+        is_vision_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
         self.model_name = model_name
 
-        if model_name in _EMBEDDING_MODELS:
+        if is_embedding_model:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
             self.model = self.wrap_device(
@@ -178,8 +178,13 @@ class HfRunner:
                 device="cpu",
             ).to(dtype=torch_dtype))
         else:
+            if is_vision_model:
+                auto_cls = AutoModelForVision2Seq
+            else:
+                auto_cls = AutoModelForCausalLM
+
             self.model = self.wrap_device(
-                AutoModelForCausalLM.from_pretrained(
+                auto_cls.from_pretrained(
                     model_name,
                     torch_dtype=torch_dtype,
                     trust_remote_code=True,
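
Instead of registering LlavaForConditionalGeneration under AutoModelForCausalLM and keeping a hard-coded list of embedding model names, the constructor now chooses the loading path from two explicit keyword-only flags. A condensed sketch of that selection logic, with device placement via wrap_device omitted and the helper name load_hf_model invented for illustration:

import torch
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq


def load_hf_model(model_name: str,
                  torch_dtype: torch.dtype = torch.half,
                  *,
                  is_embedding_model: bool = False,
                  is_vision_model: bool = False):
    """Condensed view of HfRunner.__init__'s model-loading branches."""
    if is_embedding_model:
        # Embedding models are loaded through sentence-transformers
        # (imported lazily, as in the diff, for the AMD CI).
        from sentence_transformers import SentenceTransformer
        return SentenceTransformer(model_name,
                                   device="cpu").to(dtype=torch_dtype)

    # Vision-language models go through AutoModelForVision2Seq; everything
    # else keeps using AutoModelForCausalLM.
    auto_cls = AutoModelForVision2Seq if is_vision_model else AutoModelForCausalLM
    return auto_cls.from_pretrained(model_name,
                                    torch_dtype=torch_dtype,
                                    trust_remote_code=True)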

View File

@@ -28,7 +28,7 @@ def test_models(
     model: str,
     dtype: str,
 ) -> None:
-    hf_model = hf_runner(model, dtype=dtype)
+    hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
     hf_outputs = hf_model.encode(example_prompts)
     del hf_model
 

View File

@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
     """
     model_id, vision_language_config = model_and_config
 
-    hf_model = hf_runner(model_id, dtype=dtype)
+    hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
     hf_outputs = hf_model.generate_greedy(hf_image_prompts,
                                           max_tokens,
                                           images=hf_images)
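
With the flags in place, tests opt into each loading path explicitly at the call site rather than relying on a model-name lookup inside the runner. A hedged usage sketch: the import path and the causal-LM and LLaVA model ids below are illustrative assumptions, while the embedding model id comes from the list this commit removes.

from tests.conftest import HfRunner  # assumed import path for the test helper

# Default path: decoder-only LM via AutoModelForCausalLM.
lm_runner = HfRunner("facebook/opt-125m", dtype="half")

# Embedding path: loaded with sentence-transformers.
emb_runner = HfRunner("intfloat/e5-mistral-7b-instruct",
                      dtype="half",
                      is_embedding_model=True)

# Vision-language path: loaded with AutoModelForVision2Seq.
vlm_runner = HfRunner("llava-hf/llava-1.5-7b-hf",
                      dtype="half",
                      is_vision_model=True)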