Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Doc] Consolidate whisper and florence2 examples (#14050)
commit fdcc405346 (parent 8994dabc22)
@@ -24,50 +24,7 @@ question_per_audio_count = {
# Unless specified, these settings have been tested to work on a single L4.


# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int):
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [{
        'role': 'user',
        'content': "<|audio|>\n" * audio_count + question
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=5,
              trust_remote_code=True,
              limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int):
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=5,
              limit_mm_per_prompt={"audio": audio_count})

    audio_in_prompt = "".join([
        f"Audio {idx+1}: "
        f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
    ])

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n"
              f"{audio_in_prompt}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# MiniCPM-O
def run_minicpmo(question: str, audio_count: int):
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -94,10 +51,71 @@ def run_minicpmo(question: str, audio_count: int):
    return llm, prompt, stop_token_ids


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int):
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=5,
              limit_mm_per_prompt={"audio": audio_count})

    audio_in_prompt = "".join([
        f"Audio {idx+1}: "
        f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
    ])

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n"
              f"{audio_in_prompt}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int):
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [{
        'role': 'user',
        'content': "<|audio|>\n" * audio_count + question
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=5,
              trust_remote_code=True,
              limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Whisper
def run_whisper(question: str, audio_count: int):
    assert audio_count == 1, (
        "Whisper only support single audio input per prompt")
    model_name = "openai/whisper-large-v3-turbo"

    prompt = "<|startoftranscript|>"

    llm = LLM(model=model_name,
              max_model_len=448,
              max_num_seqs=5,
              limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids

 model_example_map = {
-    "ultravox": run_ultravox,
+    "minicpmo": run_minicpmo,
     "qwen2_audio": run_qwen2_audio,
-    "minicpmo": run_minicpmo
+    "ultravox": run_ultravox,
+    "whisper": run_whisper,
 }

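For orientation, each run_* helper above returns an (llm, prompt, stop_token_ids) triple that the surrounding audio example script passes to LLM.generate together with the raw audio clips. A minimal sketch of that consumption path; the question string, sampling settings, and the single AudioAsset clip below are illustrative assumptions, not the script's actual driver code:

from vllm import SamplingParams
from vllm.assets.audio import AudioAsset

audio_count = 1
question = "What is recited in the audio?"  # illustrative question

# Pick one of the consolidated helpers, e.g. Whisper.
llm, prompt, stop_token_ids = model_example_map["whisper"](question,
                                                           audio_count)

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

# Attach the (waveform, sample_rate) tuples as multi-modal data.
inputs = {
    "prompt": prompt,
    "multi_modal_data": {
        "audio": [AudioAsset("mary_had_lamb").audio_and_sample_rate],
    },
}

outputs = llm.generate(inputs, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)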
examples/offline_inference/encoder_decoder_multimodal.py (new file, 158 lines)
@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
import time

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.utils import FlexibleArgumentParser


def run_florence2():
    # Create a Florence-2 encoder/decoder model instance
    llm = LLM(
        model="microsoft/Florence-2-large",
        tokenizer="facebook/bart-large",
        max_num_seqs=8,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # implicit prompt with task token
            "prompt": "<DETAILED_CAPTION>",
            "multi_modal_data": {
                "image": ImageAsset("stop_sign").pil_image
            },
        },
        {  # explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "Describe in detail what is shown in the image.",
                "multi_modal_data": {
                    "image": ImageAsset("cherry_blossom").pil_image
                },
            },
            "decoder_prompt": "",
        },
    ]
    return llm, prompts


def run_mllama():
    # Create a Mllama encoder/decoder model instance
    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # Implicit prompt
            "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
            "multi_modal_data": {
                "image": ImageAsset("stop_sign").pil_image,
            },
        },
        {  # Explicit prompt
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {
                    "image": ImageAsset("stop_sign").pil_image,
                },
            },
            "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
        },
    ]
    return llm, prompts


def run_whisper():
    # Create a Whisper encoder/decoder model instance
    llm = LLM(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    prompts = [
        {  # Test implicit prompt
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            },
        },
        {  # Test explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": AudioAsset("winning_call").audio_and_sample_rate,
                },
            },
            "decoder_prompt": "<|startoftranscript|>",
        }
    ]
    return llm, prompts


model_example_map = {
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,
}


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompts = model_example_map[model]()

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
    )

    start = time.time()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, "
              f"Generated text: {generated_text!r}")

    duration = time.time() - start

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="mllama",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')

    args = parser.parse_args()
    main(args)
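With the runners wired into an argparse-driven main, the consolidated example can be invoked per model type at the path shown in the file header, for example:

python examples/offline_inference/encoder_decoder_multimodal.py --model-type whisper
python examples/offline_inference/encoder_decoder_multimodal.py -m florence2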
@@ -1,53 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
"""
# TODO(Isotr0py):
# Move to offline_inference/vision_language.py
# after porting vision backbone
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Create a Florence-2 encoder/decoder model instance
llm = LLM(
    model="microsoft/Florence-2-large",
    tokenizer="facebook/bart-large",
    max_num_seqs=8,
    trust_remote_code=True,
)

prompts = [
    {  # implicit prompt with task token
        "prompt": "<DETAILED_CAPTION>",
        "multi_modal_data": {
            "image": ImageAsset("stop_sign").pil_image
        },
    },
    {  # explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "Describe in detail what is shown in the image.",
            "multi_modal_data": {
                "image": ImageAsset("cherry_blossom").pil_image
            },
        },
        "decoder_prompt": "",
    },
]
# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=128,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
@@ -1,61 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import time

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

# Create a Whisper encoder/decoder model instance
llm = LLM(
    model="openai/whisper-large-v3",
    max_model_len=448,
    max_num_seqs=400,
    limit_mm_per_prompt={"audio": 1},
    kv_cache_dtype="fp8",
)

prompts = [
    {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    },
    {  # Test explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt": "<|startoftranscript|>",
    }
] * 1024

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    max_tokens=200,
)

start = time.time()

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")

duration = time.time() - start

print("Duration:", duration)
print("RPS:", len(prompts) / duration)
@@ -748,11 +748,11 @@ def _create_fake_bias_for_k_proj(
     weights: Iterable[Tuple[str, torch.Tensor]]
 ) -> Iterable[Tuple[str, torch.Tensor]]:
     """
-    Create full zeros bias for k_proj weight in self-attention layers.
+    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
     So that the bias for k_proj in qkv_proj can be initialized with zeros.
     """
     for name, weight in weights:
-        if name.endswith(".self_attn.k_proj.weight"):
+        if name.endswith(".k_proj.weight"):
             bias = torch.zeros(weight.size(0))
             bias_name = name.replace("weight", "bias")
             yield from [(name, weight), (bias_name, bias)]
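The two changed lines above broaden the match from ".self_attn.k_proj.weight" to any ".k_proj.weight", so cross-attention key projections also receive a synthetic zero bias, matching the updated docstring. A rough, self-contained sketch of the pattern; the helper name, fake weight names, and the pass-through handling of non-matching weights below are illustrative assumptions, not vLLM's actual weight loader:

from typing import Iterable, Tuple

import torch


def add_zero_k_proj_bias(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    # For every k_proj weight (self-attn or cross-attn), also emit an
    # all-zero bias so a fused qkv_proj bias can be initialized cleanly.
    for name, weight in weights:
        yield name, weight
        if name.endswith(".k_proj.weight"):
            yield name.replace("weight", "bias"), torch.zeros(weight.size(0))


fake_weights = [
    ("layers.0.self_attn.k_proj.weight", torch.randn(8, 8)),
    ("layers.0.encoder_attn.k_proj.weight", torch.randn(8, 8)),  # cross-attn
    ("layers.0.self_attn.q_proj.weight", torch.randn(8, 8)),
]
for name, tensor in add_zero_k_proj_bias(fake_weights):
    print(name, tuple(tensor.shape))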