mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 10:30:37 +08:00
[V0 Deprecation] Enable the remaining multimodal tests in V1 (#25307)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
d88918e4c2
commit
bef180f009
@ -19,6 +19,7 @@ import socket
|
||||
import tempfile
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast
|
||||
|
||||
@ -45,14 +46,14 @@ from vllm.connections import global_http_connection
|
||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
|
||||
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.utils import maybe_model_redirect
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -306,6 +307,35 @@ class HfRunner:
|
||||
is_cross_encoder: bool = False,
|
||||
skip_tokenizer_init: bool = False,
|
||||
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
||||
# Set this to avoid hanging issue
|
||||
default_torch_num_threads: Optional[int] = None,
|
||||
) -> None:
|
||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
||||
set_default_torch_num_threads(default_torch_num_threads))
|
||||
|
||||
with init_ctx:
|
||||
self._init(
|
||||
model_name=model_name,
|
||||
dtype=dtype,
|
||||
model_kwargs=model_kwargs,
|
||||
trust_remote_code=trust_remote_code,
|
||||
is_sentence_transformer=is_sentence_transformer,
|
||||
is_cross_encoder=is_cross_encoder,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
auto_cls=auto_cls,
|
||||
)
|
||||
|
||||
def _init(
|
||||
self,
|
||||
model_name: str,
|
||||
dtype: str = "auto",
|
||||
*,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
trust_remote_code: bool = True,
|
||||
is_sentence_transformer: bool = False,
|
||||
is_cross_encoder: bool = False,
|
||||
skip_tokenizer_init: bool = False,
|
||||
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
||||
) -> None:
|
||||
model_name = maybe_model_redirect(model_name)
|
||||
self.model_name = model_name
|
||||
@ -714,26 +744,32 @@ class VllmRunner:
|
||||
enable_chunked_prefill: Optional[bool] = False,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: Optional[bool] = False,
|
||||
# Set this to avoid hanging issue
|
||||
default_torch_num_threads: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.llm = LLM(
|
||||
model=model_name,
|
||||
runner=runner,
|
||||
convert=convert,
|
||||
tokenizer=tokenizer_name,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
disable_log_stats=disable_log_stats,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
max_model_len=max_model_len,
|
||||
block_size=block_size,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
**kwargs,
|
||||
)
|
||||
init_ctx = (nullcontext() if default_torch_num_threads is None else
|
||||
set_default_torch_num_threads(default_torch_num_threads))
|
||||
|
||||
with init_ctx:
|
||||
self.llm = LLM(
|
||||
model=model_name,
|
||||
runner=runner,
|
||||
convert=convert,
|
||||
tokenizer=tokenizer_name,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
disable_log_stats=disable_log_stats,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
max_model_len=max_model_len,
|
||||
block_size=block_size,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_inputs(
|
||||
self,
|
||||
|
||||
@ -32,13 +32,6 @@ from .vlm_utils.types import (CustomTestOptions, ExpandableVLMTestArgs,
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
REQUIRES_V0_MODELS = [
|
||||
# V1 Test: not enough KV cache space in C1.
|
||||
"fuyu",
|
||||
# V1 Test: Deadlock issue when processing mm_inputs
|
||||
"llava-onevision-transformers",
|
||||
]
|
||||
|
||||
# yapf: disable
|
||||
COMMON_BROADCAST_SETTINGS = {
|
||||
"test_type": VLMTestType.IMAGE,
|
||||
@ -186,8 +179,11 @@ VLM_TEST_SETTINGS = {
|
||||
image_size_factors=[(0.25, 0.5, 1.0)],
|
||||
vllm_runner_kwargs={
|
||||
"model_impl": "transformers",
|
||||
"default_torch_num_threads": 1,
|
||||
},
|
||||
marks=[pytest.mark.core_model],
|
||||
# FIXME: Investigate why the test hangs
|
||||
# when processing the 3rd prompt in vLLM
|
||||
marks=[pytest.mark.core_model, pytest.mark.skip(reason="Test hangs")],
|
||||
),
|
||||
"idefics3-transformers": VLMTestInfo(
|
||||
models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
|
||||
@ -320,6 +316,7 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
|
||||
num_logprobs=10,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"gemma3": VLMTestInfo(
|
||||
models=["google/gemma-3-4b-it"],
|
||||
@ -861,13 +858,14 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)
|
||||
test_type=VLMTestType.IMAGE,
|
||||
create_new_process_for_each_test=False,
|
||||
))
|
||||
def test_single_image_models(tmp_path: PosixPath, model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_single_image_models(
|
||||
tmp_path: PosixPath,
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_single_image_test(
|
||||
tmp_path=tmp_path,
|
||||
@ -886,13 +884,14 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str,
|
||||
test_type=VLMTestType.MULTI_IMAGE,
|
||||
create_new_process_for_each_test=False,
|
||||
))
|
||||
def test_multi_image_models(tmp_path: PosixPath, model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_multi_image_models(
|
||||
tmp_path: PosixPath,
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_multi_image_test(
|
||||
tmp_path=tmp_path,
|
||||
@ -911,13 +910,13 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str,
|
||||
test_type=VLMTestType.EMBEDDING,
|
||||
create_new_process_for_each_test=False,
|
||||
))
|
||||
def test_image_embedding_models(model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_image_embedding_models(
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_embedding_test(
|
||||
model_test_info=model_test_info,
|
||||
@ -935,11 +934,13 @@ def test_image_embedding_models(model_type: str,
|
||||
test_type=VLMTestType.VIDEO,
|
||||
create_new_process_for_each_test=False,
|
||||
))
|
||||
def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
|
||||
video_assets: VideoTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_video_models(
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
video_assets: VideoTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_video_test(
|
||||
model_test_info=model_test_info,
|
||||
@ -957,11 +958,13 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
|
||||
test_type=VLMTestType.AUDIO,
|
||||
create_new_process_for_each_test=False,
|
||||
))
|
||||
def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_audio_models(
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_audio_test(
|
||||
model_test_info=model_test_info,
|
||||
@ -984,10 +987,7 @@ def test_custom_inputs_models(
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
monkeypatch,
|
||||
):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_custom_inputs_test(
|
||||
model_test_info=model_test_info,
|
||||
@ -1006,13 +1006,14 @@ def test_custom_inputs_models(
|
||||
create_new_process_for_each_test=True,
|
||||
))
|
||||
@create_new_process_for_each_test()
|
||||
def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_single_image_models_heavy(
|
||||
tmp_path: PosixPath,
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_single_image_test(
|
||||
tmp_path=tmp_path,
|
||||
@ -1032,13 +1033,14 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str,
|
||||
create_new_process_for_each_test=True,
|
||||
))
|
||||
@create_new_process_for_each_test()
|
||||
def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_multi_image_models_heavy(
|
||||
tmp_path: PosixPath,
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_multi_image_test(
|
||||
tmp_path=tmp_path,
|
||||
@ -1058,14 +1060,13 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str,
|
||||
create_new_process_for_each_test=True,
|
||||
))
|
||||
@create_new_process_for_each_test()
|
||||
def test_image_embedding_models_heavy(model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets,
|
||||
monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_image_embedding_models_heavy(
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
image_assets: ImageTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_embedding_test(
|
||||
model_test_info=model_test_info,
|
||||
@ -1083,12 +1084,13 @@ def test_image_embedding_models_heavy(model_type: str,
|
||||
test_type=VLMTestType.VIDEO,
|
||||
create_new_process_for_each_test=True,
|
||||
))
|
||||
def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
video_assets: VideoTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_video_models_heavy(
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
video_assets: VideoTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_video_test(
|
||||
model_test_info=model_test_info,
|
||||
@ -1106,12 +1108,13 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
|
||||
test_type=VLMTestType.AUDIO,
|
||||
create_new_process_for_each_test=True,
|
||||
))
|
||||
def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets, monkeypatch):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
def test_audio_models_heavy(
|
||||
model_type: str,
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_audio_test(
|
||||
model_test_info=model_test_info,
|
||||
@ -1135,10 +1138,7 @@ def test_custom_inputs_models_heavy(
|
||||
test_case: ExpandableVLMTestArgs,
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
monkeypatch,
|
||||
):
|
||||
if model_type in REQUIRES_V0_MODELS:
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
model_test_info = VLM_TEST_SETTINGS[model_type]
|
||||
runners.run_custom_inputs_test(
|
||||
model_test_info=model_test_info,
|
||||
|
||||
@ -12,13 +12,12 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt
|
||||
from vllm import SamplingParams, TextPrompt, TokensPrompt
|
||||
from vllm.multimodal import MultiModalDataBuiltins
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sequence import Logprob, SampleLogprobs
|
||||
|
||||
from ....utils import VLLM_PATH, large_gpu_test
|
||||
from ...utils import check_logprobs_close, dummy_hf_overrides
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import StrPath
|
||||
@ -185,47 +184,3 @@ def test_chat(vllm_runner, max_model_len: int, model: str, dtype: str,
|
||||
outputs_1_lst=logprobs,
|
||||
name_0="h100_ref",
|
||||
name_1="output")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls,expected_ranges",
|
||||
[(IMG_URLS[:1], [PlaceholderRange(offset=11, length=494)]),
|
||||
(IMG_URLS[1:4], [
|
||||
PlaceholderRange(offset=11, length=266),
|
||||
PlaceholderRange(offset=277, length=1056),
|
||||
PlaceholderRange(offset=1333, length=418)
|
||||
])])
|
||||
def test_multi_modal_placeholders(vllm_runner, image_urls: list[str],
|
||||
expected_ranges: list[PlaceholderRange],
|
||||
local_asset_server, monkeypatch) -> None:
|
||||
local_image_urls = [local_asset_server.url_for(u) for u in image_urls]
|
||||
prompt = _create_engine_inputs_hf(local_image_urls)
|
||||
|
||||
# This placeholder checking test only works with V0 engine
|
||||
# where `multi_modal_placeholders` is returned with `RequestOutput`
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(
|
||||
"mistral-community/pixtral-12b",
|
||||
max_model_len=8192,
|
||||
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
|
||||
load_format="dummy",
|
||||
hf_overrides=dummy_hf_overrides,
|
||||
) as vllm_model:
|
||||
outputs = vllm_model.llm.generate(prompt)
|
||||
|
||||
assert len(outputs) == 1, f"{len(outputs)=}"
|
||||
output: RequestOutput = outputs[0]
|
||||
assert hasattr(output,
|
||||
"multi_modal_placeholders"), f"{output.__dict__=}"
|
||||
assert "image" in output.multi_modal_placeholders, \
|
||||
f"{output.multi_modal_placeholders.keys()=}"
|
||||
image_placeholder_ranges: list[
|
||||
PlaceholderRange] = output.multi_modal_placeholders["image"]
|
||||
assert len(image_placeholder_ranges) == len(
|
||||
expected_ranges), f"{image_placeholder_ranges=}"
|
||||
for real_range, expected_range in zip(image_placeholder_ranges,
|
||||
expected_ranges):
|
||||
assert real_range.offset == expected_range.offset, \
|
||||
f"{real_range=} {expected_range=}"
|
||||
assert real_range.length == expected_range.length, \
|
||||
f"{real_range=} {expected_range=}"
|
||||
|
||||
@ -10,7 +10,6 @@ from PIL import Image
|
||||
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
|
||||
from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
|
||||
PromptVideoInput, VllmRunner)
|
||||
@ -264,8 +263,7 @@ def run_embedding_input_test(
|
||||
processor = AutoProcessor.from_pretrained(model)
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with set_default_torch_num_threads(1):
|
||||
vllm_model = vllm_runner(
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="generate",
|
||||
max_model_len=4000,
|
||||
@ -277,9 +275,8 @@ def run_embedding_input_test(
|
||||
},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
)
|
||||
|
||||
with vllm_model:
|
||||
default_torch_num_threads=1,
|
||||
) as vllm_model:
|
||||
outputs_per_case_for_original_input = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
|
||||
@ -4,8 +4,6 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
|
||||
|
||||
@ -30,19 +28,17 @@ def _run_test(
|
||||
} for _ in range(10)
|
||||
]
|
||||
|
||||
with (
|
||||
set_default_torch_num_threads(1),
|
||||
vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=torch.float16,
|
||||
enforce_eager=True,
|
||||
skip_tokenizer_init=True,
|
||||
# Limit the maximum number of sequences to avoid the
|
||||
# test going OOM during the warmup run
|
||||
max_num_seqs=32,
|
||||
) as vllm_model,
|
||||
):
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype="half",
|
||||
enforce_eager=True,
|
||||
skip_tokenizer_init=True,
|
||||
# Limit the maximum number of sequences to avoid the
|
||||
# test going OOM during the warmup run
|
||||
max_num_seqs=32,
|
||||
default_torch_num_threads=1,
|
||||
) as vllm_model:
|
||||
vllm_model.encode(prompt)
|
||||
|
||||
|
||||
|
||||
@ -45,12 +45,15 @@ def run_awq_test(
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(source_model,
|
||||
max_model_len=4096,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
with vllm_runner(
|
||||
source_model,
|
||||
max_model_len=4096,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
default_torch_num_threads=1,
|
||||
) as vllm_model:
|
||||
source_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
@ -59,13 +62,16 @@ def run_awq_test(
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with vllm_runner(quant_model,
|
||||
quantization="awq",
|
||||
max_model_len=4096,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
with vllm_runner(
|
||||
quant_model,
|
||||
quantization="awq",
|
||||
max_model_len=4096,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
default_torch_num_threads=1,
|
||||
) as vllm_model:
|
||||
quant_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
@ -108,12 +114,8 @@ def run_awq_test(
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@torch.inference_mode()
|
||||
def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
|
||||
size_factors, dtype, max_tokens, num_logprobs,
|
||||
monkeypatch) -> None:
|
||||
size_factors, dtype, max_tokens, num_logprobs) -> None:
|
||||
|
||||
# Test V1: this test hangs during setup on single-scale input.
|
||||
# TODO: figure out why and re-enable this on V1.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
run_awq_test(
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
|
||||
@ -5,10 +5,7 @@
|
||||
Run `pytest tests/quantization/test_bitsandbytes.py`.
|
||||
'''
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
@ -131,12 +128,15 @@ def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts,
|
||||
))
|
||||
with vllm_runner(model_name,
|
||||
quantization='bitsandbytes',
|
||||
enforce_eager=False) as llm:
|
||||
enforce_eager=False,
|
||||
default_torch_num_threads=1) as llm:
|
||||
vllm_outputs = llm.generate_greedy_logprobs(example_prompts,
|
||||
max_tokens=32,
|
||||
num_logprobs=5)
|
||||
|
||||
with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
|
||||
with hf_runner(model_name,
|
||||
model_kwargs=hf_model_kwargs,
|
||||
default_torch_num_threads=1) as llm:
|
||||
transformers_outputs = llm.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens=32, num_logprobs=5)
|
||||
check_logprobs_close(
|
||||
@ -174,7 +174,8 @@ def test_4bit_bnb_embedding_model(
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
gpu_memory_utilization=0.5,
|
||||
quantization="bitsandbytes") as vllm_model:
|
||||
quantization="bitsandbytes",
|
||||
default_torch_num_threads=1) as vllm_model:
|
||||
vllm_outputs = vllm_model.embed(example_prompts)
|
||||
|
||||
hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
|
||||
@ -184,6 +185,7 @@ def test_4bit_bnb_embedding_model(
|
||||
dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs,
|
||||
is_sentence_transformer=True,
|
||||
default_torch_num_threads=1,
|
||||
) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
|
||||
@ -222,26 +224,22 @@ def validate_generated_texts(hf_runner,
|
||||
with vllm_runner(model_name,
|
||||
quantization=None if pre_quant else 'bitsandbytes',
|
||||
tensor_parallel_size=vllm_tp_size,
|
||||
enforce_eager=False) as llm:
|
||||
enforce_eager=False,
|
||||
default_torch_num_threads=1) as llm:
|
||||
|
||||
vllm_outputs = llm.generate_greedy(prompts, max_tokens)
|
||||
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
|
||||
|
||||
# Clean up the GPU memory for the next test
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if hf_model_kwargs is None:
|
||||
hf_model_kwargs = {}
|
||||
|
||||
# Run with HF runner
|
||||
with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
|
||||
with hf_runner(model_name,
|
||||
model_kwargs=hf_model_kwargs,
|
||||
default_torch_num_threads=1) as llm:
|
||||
hf_outputs = llm.generate_greedy(prompts, max_tokens)
|
||||
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
|
||||
|
||||
# Clean up the GPU memory for the next test
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
# Compare the generated strings
|
||||
for hf_log, vllm_log in zip(hf_logs, vllm_logs):
|
||||
hf_str = hf_log["generated_text"]
|
||||
|
||||
@ -5,7 +5,6 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.conftest import VllmRunner
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -25,19 +24,17 @@ def test_inference(
|
||||
prompt = dict(prompt_token_ids=[1],
|
||||
multi_modal_data=dict(pixel_values=pixel_values,
|
||||
location_coords=location_coords))
|
||||
with (
|
||||
set_default_torch_num_threads(1),
|
||||
vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=torch.float16,
|
||||
enforce_eager=True,
|
||||
skip_tokenizer_init=True,
|
||||
# Limit the maximum number of sequences to avoid the
|
||||
# test going OOM during the warmup run
|
||||
max_num_seqs=32,
|
||||
) as vllm_model,
|
||||
):
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype="half",
|
||||
enforce_eager=True,
|
||||
skip_tokenizer_init=True,
|
||||
# Limit the maximum number of sequences to avoid the
|
||||
# test going OOM during the warmup run
|
||||
max_num_seqs=32,
|
||||
default_torch_num_threads=1,
|
||||
) as vllm_model:
|
||||
|
||||
vllm_output = vllm_model.llm.encode(prompt)
|
||||
assert torch.equal(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user