mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 10:18:39 +08:00
114 lines | 3.4 KiB | Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
def test_idefics_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    """Smoke-test text classification through the Idefics3 pooling runner.

    Loads HuggingFaceM4/Idefics3-8B-Llama3 with dummy (random) weights,
    converted to a classifier, and checks that every prompt comes back
    with exactly two class probabilities.
    """
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    sample_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Keep the engine small and deterministic to load: dummy weights,
    # eager mode, single GPU, short context.
    runner_kwargs = dict(
        model_name="HuggingFaceM4/Idefics3-8B-Llama3",
        runner="pooling",
        task="classify",
        convert="classify",
        load_format="dummy",
        max_model_len=512,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_log_stats=True,
        dtype="bfloat16",
    )

    with vllm_runner(**runner_kwargs) as vllm_model:
        llm = vllm_model.get_llm()
        classifications = llm.classify(sample_prompts)
        for classification in classifications:
            # The converted classifier is expected to expose two labels.
            assert len(classification.outputs.probs) == 2
|
|
def update_config(config):
    """HF-config hook: turn the Gemma3 text backbone into a 5-way classifier.

    Rewrites ``config.text_config`` so the model is loaded as
    ``Gemma3ForSequenceClassification`` whose class logits are taken
    directly from the answer tokens "A".."E" (no extra post-processing),
    each mapped to a furniture label.

    Returns the same ``config`` object, mutated in place.
    """
    classifier_overrides = {
        "architectures": ["Gemma3ForSequenceClassification"],
        "classifier_from_token": ["A", "B", "C", "D", "E"],
        "method": "no_post_processing",
        "id2label": {
            "A": "Chair",
            "B": "Couch",
            "C": "Table",
            "D": "Bed",
            "E": "Cupboard",
        },
    }
    config.text_config.update(classifier_overrides)
    return config
|
def test_gemma_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    """End-to-end multimodal classification with Gemma3.

    Converts google/gemma-3-4b-it into a sequence classifier via
    ``update_config``, feeds it a chair image plus a short text
    description through the chat preprocessor, and expects the first
    label ("A" = chair) to receive nearly all of the probability mass.
    """
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    system_message = {
        "role": "system",
        "content": """
You are a helpful assistant. You will be given a product description
which may also include an image. Classify the following product into
one of the categories:

A = chair
B = couch
C = table
D = bed
E = cupboard

You'll answer with exactly one letter (A, B, C, D, or E).""",
    }
    user_message = {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
                },
            },
            {
                "type": "text",
                "text": "A fine 19th century piece of furniture.",
            },
        ],
    }
    messages = [system_message, user_message]

    # Real weights are needed here (load_format="auto") because the test
    # asserts on actual classification quality, not just output shape.
    runner_kwargs = dict(
        model_name="google/gemma-3-4b-it",
        runner="pooling",
        task="classify",
        convert="classify",
        load_format="auto",
        hf_overrides=update_config,
        override_pooler_config={"pooling_type": "LAST"},
        max_model_len=512,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_log_stats=True,
        dtype="bfloat16",
    )

    with vllm_runner(**runner_kwargs) as vllm_model:
        llm = vllm_model.get_llm()
        prompts = llm.preprocess_chat(messages)
        result = llm.classify(prompts)
        # Label "A" (chair, index 0) should dominate; all other labels
        # should be negligible.
        assert result[0].outputs.probs[0] > 0.95
        assert all(p < 0.05 for p in result[0].outputs.probs[1:])