Convert examples to ruff-format (#18400)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-05-26 17:57:54 +01:00 committed by GitHub
parent e7523c2e03
commit 27bebcd897
83 changed files with 2529 additions and 2405 deletions
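The change is mechanical: every example file is rewritten from its previous yapf-style layout into ruff-format (Black-compatible) output. A representative sketch of the kind of rewrite applied throughout, mirroring a hunk that appears later in this diff (illustrative only):

```python
# Before: yapf-style hanging indentation aligned to the opening parenthesis
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=req_data.stop_token_ids)

# After: ruff-format wraps once and indents arguments by four spaces
sampling_params = SamplingParams(
    temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)
```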

View File

@@ -17,7 +17,7 @@ repos:
   - id: ruff
     args: [--output-format, github, --fix]
   - id: ruff-format
-    files: ^(.buildkite|benchmarks)/.*
+    files: ^(.buildkite|benchmarks|examples)/.*
 - repo: https://github.com/codespell-project/codespell
   rev: v2.4.1
   hooks:

View File

@@ -6,6 +6,7 @@ with the correct prompt format on audio language models.
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+
 import os
 from dataclasses import asdict
 from typing import NamedTuple, Optional
@@ -22,7 +23,7 @@ audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
 question_per_audio_count = {
     0: "What is 1+1?",
     1: "What is recited in the audio?",
-    2: "What sport and what nursery rhyme are referenced?"
+    2: "What sport and what nursery rhyme are referenced?",
 }
@@ -72,8 +73,7 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
 # MiniCPM-O
 def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
     model_name = "openbmb/MiniCPM-o-2_6"
-    tokenizer = AutoTokenizer.from_pretrained(model_name,
-                                              trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     engine_args = EngineArgs(
         model=model_name,
         trust_remote_code=True,
@@ -82,19 +82,18 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
         limit_mm_per_prompt={"audio": audio_count},
     )

-    stop_tokens = ['<|im_end|>', '<|endoftext|>']
+    stop_tokens = ["<|im_end|>", "<|endoftext|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

     audio_placeholder = "(<audio>./</audio>)" * audio_count
     audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"  # noqa: E501
-    messages = [{
-        'role': 'user',
-        'content': f'{audio_placeholder}\n{question}'
-    }]
-    prompt = tokenizer.apply_chat_template(messages,
-                                           tokenize=False,
-                                           add_generation_prompt=True,
-                                           chat_template=audio_chat_template)
+    messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        chat_template=audio_chat_template,
+    )

     return ModelRequestData(
         engine_args=engine_args,
@@ -113,7 +112,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
     # Since the vision-lora and speech-lora co-exist with the base model,
     # we have to manually specify the path of the lora weights.
     speech_lora_path = os.path.join(model_path, "speech-lora")
-    placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
+    placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])

     prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
@@ -145,15 +144,19 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
         limit_mm_per_prompt={"audio": audio_count},
     )

-    audio_in_prompt = "".join([
-        f"Audio {idx+1}: "
-        f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
-    ])
+    audio_in_prompt = "".join(
+        [
+            f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
+            for idx in range(audio_count)
+        ]
+    )

-    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-              "<|im_start|>user\n"
-              f"{audio_in_prompt}{question}<|im_end|>\n"
-              "<|im_start|>assistant\n")
+    prompt = (
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+        "<|im_start|>user\n"
+        f"{audio_in_prompt}{question}<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )

     return ModelRequestData(
         engine_args=engine_args,
@@ -172,19 +175,22 @@ def run_qwen2_5_omni(question: str, audio_count: int):
         limit_mm_per_prompt={"audio": audio_count},
     )

-    audio_in_prompt = "".join([
-        "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
-    ])
+    audio_in_prompt = "".join(
+        ["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)]
+    )
     default_system = (
         "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
         "Group, capable of perceiving auditory and visual inputs, as well as "
-        "generating text and speech.")
+        "generating text and speech."
+    )

-    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
-              "<|im_start|>user\n"
-              f"{audio_in_prompt}{question}<|im_end|>\n"
-              "<|im_start|>assistant\n")
+    prompt = (
+        f"<|im_start|>system\n{default_system}<|im_end|>\n"
+        "<|im_start|>user\n"
+        f"{audio_in_prompt}{question}<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )

     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
@@ -196,13 +202,10 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
     model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    messages = [{
-        'role': 'user',
-        'content': "<|audio|>\n" * audio_count + question
-    }]
-    prompt = tokenizer.apply_chat_template(messages,
-                                           tokenize=False,
-                                           add_generation_prompt=True)
+    messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )

     engine_args = EngineArgs(
         model=model_name,
@@ -220,8 +223,7 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
 # Whisper
 def run_whisper(question: str, audio_count: int) -> ModelRequestData:
-    assert audio_count == 1, (
-        "Whisper only support single audio input per prompt")
+    assert audio_count == 1, "Whisper only support single audio input per prompt"
     model_name = "openai/whisper-large-v3-turbo"

     prompt = "<|startoftranscript|>"
@@ -252,27 +254,33 @@ model_example_map = {
 def parse_args():
     parser = FlexibleArgumentParser(
-        description='Demo on using vLLM for offline inference with '
-        'audio language models')
-    parser.add_argument('--model-type',
-                        '-m',
-                        type=str,
-                        default="ultravox",
-                        choices=model_example_map.keys(),
-                        help='Huggingface "model_type".')
-    parser.add_argument('--num-prompts',
-                        type=int,
-                        default=1,
-                        help='Number of prompts to run.')
-    parser.add_argument("--num-audios",
-                        type=int,
-                        default=1,
-                        choices=[0, 1, 2],
-                        help="Number of audio items per prompt.")
-    parser.add_argument("--seed",
-                        type=int,
-                        default=None,
-                        help="Set the seed when initializing `vllm.LLM`.")
+        description="Demo on using vLLM for offline inference with "
+        "audio language models"
+    )
+    parser.add_argument(
+        "--model-type",
+        "-m",
+        type=str,
+        default="ultravox",
+        choices=model_example_map.keys(),
+        help='Huggingface "model_type".',
+    )
+    parser.add_argument(
+        "--num-prompts", type=int, default=1, help="Number of prompts to run."
+    )
+    parser.add_argument(
+        "--num-audios",
+        type=int,
+        default=1,
+        choices=[0, 1, 2],
+        help="Number of audio items per prompt.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="Set the seed when initializing `vllm.LLM`.",
+    )

     return parser.parse_args()
@@ -283,29 +291,30 @@ def main(args):
         raise ValueError(f"Model type {model} is not supported.")

     audio_count = args.num_audios
-    req_data = model_example_map[model](question_per_audio_count[audio_count],
-                                        audio_count)
+    req_data = model_example_map[model](
+        question_per_audio_count[audio_count], audio_count
+    )

     # Disable other modalities to save memory
     default_limits = {"image": 0, "video": 0, "audio": 0}
     req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
-        req_data.engine_args.limit_mm_per_prompt or {})
+        req_data.engine_args.limit_mm_per_prompt or {}
+    )

     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)

     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
-    sampling_params = SamplingParams(temperature=0.2,
-                                     max_tokens=64,
-                                     stop_token_ids=req_data.stop_token_ids)
+    sampling_params = SamplingParams(
+        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
+    )

     mm_data = {}
     if audio_count > 0:
         mm_data = {
             "audio": [
-                asset.audio_and_sample_rate
-                for asset in audio_assets[:audio_count]
+                asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
             ]
         }
@@ -315,8 +324,9 @@ def main(args):
     # Batch inference
     inputs = [inputs] * args.num_prompts
     # Add LoRA request if applicable
-    lora_request = (req_data.lora_requests *
-                    args.num_prompts if req_data.lora_requests else None)
+    lora_request = (
+        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
+    )

     outputs = llm.generate(
         inputs,

View File

@@ -16,13 +16,16 @@ but ask different questions.
 Run:
     python examples/offline_inference/automatic_prefix_caching.py
 """
+
 import time

 from vllm import LLM, SamplingParams

 # ruff: noqa: E501
 # A prompt containing a large markdown table. The table is randomly generated by GPT-4.
-LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
+LONG_PROMPT = (
+    "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n"
+    + """
 | ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
 |-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
 | 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
@@ -56,6 +59,7 @@ LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables i
 | 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
 | 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
 """
+)


 def get_generation_time(llm, sampling_params, prompts):
@@ -72,7 +76,7 @@ def get_generation_time(llm, sampling_params, prompts):
 def main():
     # set enable_prefix_caching=True to enable APC
-    llm = LLM(model='lmsys/longchat-13b-16k', enable_prefix_caching=True)
+    llm = LLM(model="lmsys/longchat-13b-16k", enable_prefix_caching=True)

     sampling_params = SamplingParams(temperature=0, max_tokens=100)
@@ -80,8 +84,8 @@ def main():
     get_generation_time(
         llm,
         sampling_params,
-        LONG_PROMPT +
-        "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
+        LONG_PROMPT
+        + "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
     )

     # Querying the age of Zack Blue
@@ -89,8 +93,8 @@ def main():
     get_generation_time(
         llm,
         sampling_params,
-        LONG_PROMPT +
-        "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
+        LONG_PROMPT
+        + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
     )

View File

@@ -56,22 +56,12 @@ def main(args: dict):
     # In this script, we demonstrate how to pass input to the chat method:
     conversation = [
-        {
-            "role": "system",
-            "content": "You are a helpful assistant"
-        },
-        {
-            "role": "user",
-            "content": "Hello"
-        },
-        {
-            "role": "assistant",
-            "content": "Hello! How can I assist you today?"
-        },
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
         {
             "role": "user",
-            "content":
-            "Write an essay about the importance of higher education.",
+            "content": "Write an essay about the importance of higher education.",
         },
     ]
     outputs = llm.chat(conversation, sampling_params, use_tqdm=False)

View File

@@ -10,9 +10,9 @@ def parse_args():
     parser = FlexibleArgumentParser()
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
-    parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
-                        task="classify",
-                        enforce_eager=True)
+    parser.set_defaults(
+        model="jason9693/Qwen2.5-1.5B-apeach", task="classify", enforce_eager=True
+    )
     return parser.parse_args()
@@ -36,10 +36,11 @@ def main(args: Namespace):
     print("\nGenerated Outputs:\n" + "-" * 60)
     for prompt, output in zip(prompts, outputs):
         probs = output.outputs.probs
-        probs_trimmed = ((str(probs[:16])[:-1] +
-                          ", ...]") if len(probs) > 16 else probs)
-        print(f"Prompt: {prompt!r} \n"
-              f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
+        probs_trimmed = (str(probs[:16])[:-1] + ", ...]") if len(probs) > 16 else probs
+        print(
+            f"Prompt: {prompt!r} \n"
+            f"Class Probabilities: {probs_trimmed} (size={len(probs)})"
+        )
         print("-" * 60)

View File

@@ -10,9 +10,9 @@ def parse_args():
     parser = FlexibleArgumentParser()
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
-    parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
-                        task="embed",
-                        enforce_eager=True)
+    parser.set_defaults(
+        model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
+    )
     return parser.parse_args()
@@ -36,10 +36,10 @@ def main(args: Namespace):
     print("\nGenerated Outputs:\n" + "-" * 60)
     for prompt, output in zip(prompts, outputs):
         embeds = output.outputs.embedding
-        embeds_trimmed = ((str(embeds[:16])[:-1] +
-                           ", ...]") if len(embeds) > 16 else embeds)
-        print(f"Prompt: {prompt!r} \n"
-              f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
+        embeds_trimmed = (
+            (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
+        )
+        print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
         print("-" * 60)

View File

@@ -10,9 +10,9 @@ def parse_args():
     parser = FlexibleArgumentParser()
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
-    parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
-                        task="score",
-                        enforce_eager=True)
+    parser.set_defaults(
+        model="BAAI/bge-reranker-v2-m3", task="score", enforce_eager=True
+    )
     return parser.parse_args()

View File

@@ -17,12 +17,14 @@ Ray Data provides functionality for:
 Learn more about Ray Data's LLM integration:
 https://docs.ray.io/en/latest/data/working-with-llms.html
 """
+
 import ray
 from packaging.version import Version
 from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

-assert Version(ray.__version__) >= Version(
-    "2.44.1"), "Ray version must be at least 2.44.1"
+assert Version(ray.__version__) >= Version("2.44.1"), (
+    "Ray version must be at least 2.44.1"
+)

 # Uncomment to reduce clutter in stdout
 # ray.init(log_to_driver=False)
@@ -53,20 +55,18 @@ config = vLLMEngineProcessorConfig(
 vllm_processor = build_llm_processor(
     config,
     preprocess=lambda row: dict(
-        messages=[{
-            "role": "system",
-            "content": "You are a bot that responds with haikus."
-        }, {
-            "role": "user",
-            "content": row["text"]
-        }],
+        messages=[
+            {"role": "system", "content": "You are a bot that responds with haikus."},
+            {"role": "user", "content": row["text"]},
+        ],
         sampling_params=dict(
             temperature=0.3,
             max_tokens=250,
-        )),
+        ),
+    ),
     postprocess=lambda row: dict(
         answer=row["generated_text"],
-        **row  # This will return all the original columns in the dataset.
+        **row,  # This will return all the original columns in the dataset.
     ),
 )

View File

@@ -50,87 +50,93 @@ model_name = "mistralai/Mistral-7B-Instruct-v0.3"
 # or any other mistral model with function calling ability

 sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
-llm = LLM(model=model_name,
-          tokenizer_mode="mistral",
-          config_format="mistral",
-          load_format="mistral")
+llm = LLM(
+    model=model_name,
+    tokenizer_mode="mistral",
+    config_format="mistral",
+    load_format="mistral",
+)


 def generate_random_id(length=9):
     characters = string.ascii_letters + string.digits
-    random_id = ''.join(random.choice(characters) for _ in range(length))
+    random_id = "".join(random.choice(characters) for _ in range(length))
     return random_id


 # simulate an API that can be called
-def get_current_weather(city: str, state: str, unit: 'str'):
-    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
-            "partly cloudly, with highs in the 90's.")
+def get_current_weather(city: str, state: str, unit: "str"):
+    return (
+        f"The weather in {city}, {state} is 85 degrees {unit}. It is "
+        "partly cloudly, with highs in the 90's."
+    )


 tool_functions = {"get_current_weather": get_current_weather}

-tools = [{
-    "type": "function",
-    "function": {
-        "name": "get_current_weather",
-        "description": "Get the current weather in a given location",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "city": {
-                    "type":
-                    "string",
-                    "description":
-                    "The city to find the weather for, e.g. 'San Francisco'"
-                },
-                "state": {
-                    "type":
-                    "string",
-                    "description":
-                    "the two-letter abbreviation for the state that the city is"
-                    " in, e.g. 'CA' which would mean 'California'"
-                },
-                "unit": {
-                    "type": "string",
-                    "description": "The unit to fetch the temperature in",
-                    "enum": ["celsius", "fahrenheit"]
-                }
-            },
-            "required": ["city", "state", "unit"]
-        }
-    }
-}]
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description": "The city to find the weather for, e.g. 'San Francisco'",
+                    },
+                    "state": {
+                        "type": "string",
+                        "description": "the two-letter abbreviation for the state that the city is"
+                        " in, e.g. 'CA' which would mean 'California'",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "description": "The unit to fetch the temperature in",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["city", "state", "unit"],
+            },
+        },
+    }
+]

-messages = [{
-    "role":
-    "user",
-    "content":
-    "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
-}]
+messages = [
+    {
+        "role": "user",
+        "content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?",
+    }
+]

 outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
 output = outputs[0].outputs[0].text.strip()

 # append the assistant message
-messages.append({
-    "role": "assistant",
-    "content": output,
-})
+messages.append(
+    {
+        "role": "assistant",
+        "content": output,
+    }
+)

 # let's now actually parse and execute the model's output simulating an API call by using the
 # above defined function
 tool_calls = json.loads(output)
 tool_answers = [
-    tool_functions[call['name']](**call['arguments']) for call in tool_calls
+    tool_functions[call["name"]](**call["arguments"]) for call in tool_calls
 ]

 # append the answer as a tool message and let the LLM give you an answer
-messages.append({
-    "role": "tool",
-    "content": "\n\n".join(tool_answers),
-    "tool_call_id": generate_random_id(),
-})
+messages.append(
+    {
+        "role": "tool",
+        "content": "\n\n".join(tool_answers),
+        "tool_call_id": generate_random_id(),
+    }
+)

 outputs = llm.chat(messages, sampling_params, tools=tools)

View File

@@ -27,6 +27,7 @@ Multi-node:
                         --master-addr=10.99.48.128 \
                         --master-port=13345
 """
+
 import os
 from time import sleep
@@ -36,46 +37,46 @@ from vllm.utils import get_open_port
 def parse_args():
     import argparse

     parser = argparse.ArgumentParser(description="Data Parallel Inference")
-    parser.add_argument("--model",
-                        type=str,
-                        default="ibm-research/PowerMoE-3b",
-                        help="Model name or path")
-    parser.add_argument("--dp-size",
-                        type=int,
-                        default=2,
-                        help="Data parallel size")
-    parser.add_argument("--tp-size",
-                        type=int,
-                        default=2,
-                        help="Tensor parallel size")
-    parser.add_argument("--node-size",
-                        type=int,
-                        default=1,
-                        help="Total number of nodes")
-    parser.add_argument("--node-rank",
-                        type=int,
-                        default=0,
-                        help="Rank of the current node")
-    parser.add_argument("--master-addr",
-                        type=str,
-                        default="",
-                        help="Master node IP address")
-    parser.add_argument("--master-port",
-                        type=int,
-                        default=0,
-                        help="Master node port")
-    parser.add_argument("--enforce-eager",
-                        action='store_true',
-                        help="Enforce eager mode execution.")
-    parser.add_argument("--trust-remote-code",
-                        action='store_true',
-                        help="Trust remote code.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="ibm-research/PowerMoE-3b",
+        help="Model name or path",
+    )
+    parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
+    parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
+    parser.add_argument(
+        "--node-size", type=int, default=1, help="Total number of nodes"
+    )
+    parser.add_argument(
+        "--node-rank", type=int, default=0, help="Rank of the current node"
+    )
+    parser.add_argument(
+        "--master-addr", type=str, default="", help="Master node IP address"
+    )
+    parser.add_argument("--master-port", type=int, default=0, help="Master node port")
+    parser.add_argument(
+        "--enforce-eager", action="store_true", help="Enforce eager mode execution."
+    )
+    parser.add_argument(
+        "--trust-remote-code", action="store_true", help="Trust remote code."
+    )
     return parser.parse_args()


-def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
+def main(
+    model,
+    dp_size,
+    local_dp_rank,
+    global_dp_rank,
+    dp_master_ip,
+    dp_master_port,
+    GPUs_per_dp_rank,
+    enforce_eager,
+    trust_remote_code,
+):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -110,9 +111,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
     # since we are doing data parallel, every rank can have different
     # sampling params. here we set different max_tokens for different
     # ranks for demonstration.
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=[16, 20][global_dp_rank % 2])
+    sampling_params = SamplingParams(
+        temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
+    )

     # Create an LLM.
     llm = LLM(
@@ -130,15 +131,16 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
             break
         prompt = output.prompt
         generated_text = output.outputs[0].text
-        print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
-              f"Generated text: {generated_text!r}")
+        print(
+            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
+            f"Generated text: {generated_text!r}"
+        )

     # Give engines time to pause their processing loops before exiting.
     sleep(1)


 if __name__ == "__main__":
     args = parse_args()

     dp_size = args.dp_size
@@ -160,20 +162,29 @@ if __name__ == "__main__":
     procs = []
     for local_dp_rank, global_dp_rank in enumerate(
-            range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
-        proc = Process(target=main,
-                       args=(args.model, dp_size, local_dp_rank,
-                             global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size, args.enforce_eager,
-                             args.trust_remote_code))
+        range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
+    ):
+        proc = Process(
+            target=main,
+            args=(
+                args.model,
+                dp_size,
+                local_dp_rank,
+                global_dp_rank,
+                dp_master_ip,
+                dp_master_port,
+                tp_size,
+                args.enforce_eager,
+                args.trust_remote_code,
+            ),
+        )
         proc.start()
         procs.append(proc)
     exit_code = 0
     for proc in procs:
         proc.join(timeout=300)
         if proc.exitcode is None:
-            print(f"Killing process {proc.pid} that "
-                  f"didn't stop within 5 minutes.")
+            print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
             proc.kill()
             exit_code = 1
         elif proc.exitcode:
View File

@@ -22,17 +22,18 @@ def main():
     prompts = read_prompts()
     sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-              enforce_eager=True,
-              gpu_memory_utilization=0.8,
-              max_num_batched_tokens=64,
-              max_num_seqs=16,
-              kv_transfer_config=KVTransferConfig(
-                  kv_connector="SharedStorageConnector",
-                  kv_role="kv_both",
-                  kv_connector_extra_config={
-                      "shared_storage_path": "local_storage"
-                  }))  #, max_model_len=2048, max_num_batched_tokens=2048)
+    llm = LLM(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        enforce_eager=True,
+        gpu_memory_utilization=0.8,
+        max_num_batched_tokens=64,
+        max_num_seqs=16,
+        kv_transfer_config=KVTransferConfig(
+            kv_connector="SharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={"shared_storage_path": "local_storage"},
+        ),
+    )  # , max_model_len=2048, max_num_batched_tokens=2048)

     # 1ST generation (prefill instance)
     outputs = llm.generate(prompts, sampling_params)

View File

@@ -20,15 +20,16 @@ def main():
     sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

-    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
-              enforce_eager=True,
-              gpu_memory_utilization=0.8,
-              kv_transfer_config=KVTransferConfig(
-                  kv_connector="SharedStorageConnector",
-                  kv_role="kv_both",
-                  kv_connector_extra_config={
-                      "shared_storage_path": "local_storage"
-                  }))  #, max_model_len=2048, max_num_batched_tokens=2048)
+    llm = LLM(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        enforce_eager=True,
+        gpu_memory_utilization=0.8,
+        kv_transfer_config=KVTransferConfig(
+            kv_connector="SharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={"shared_storage_path": "local_storage"},
+        ),
+    )  # , max_model_len=2048, max_num_batched_tokens=2048)

     # 1ST generation (prefill instance)
     outputs = llm.generate(

View File

@@ -4,6 +4,7 @@ This file demonstrates the example usage of disaggregated prefilling
 We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
 and then transfer the KV cache between them.
 """
+
 import os
 import time
 from multiprocessing import Event, Process
@@ -32,17 +33,21 @@ def run_prefill(prefill_done):
     # This instance is the prefill node (kv_producer, rank 0).
     # The number of parallel instances for KV cache transfer is set to 2,
     # as required for PyNcclConnector.
-    ktc = KVTransferConfig(kv_connector="PyNcclConnector",
-                           kv_role="kv_producer",
-                           kv_rank=0,
-                           kv_parallel_size=2)
+    ktc = KVTransferConfig(
+        kv_connector="PyNcclConnector",
+        kv_role="kv_producer",
+        kv_rank=0,
+        kv_parallel_size=2,
+    )

     # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
     # memory. You may need to adjust the value to fit your GPU.
-    llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
-              kv_transfer_config=ktc,
-              max_model_len=2000,
-              gpu_memory_utilization=0.8)
+    llm = LLM(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        kv_transfer_config=ktc,
+        max_model_len=2000,
+        gpu_memory_utilization=0.8,
+    )

     llm.generate(prompts, sampling_params)
     print("Prefill node is finished.")
@@ -72,17 +77,21 @@ def run_decode(prefill_done):
     # This instance is the decode node (kv_consumer, rank 1).
     # The number of parallel instances for KV cache transfer is set to 2,
     # as required for PyNcclConnector.
-    ktc = KVTransferConfig(kv_connector="PyNcclConnector",
-                           kv_role="kv_consumer",
-                           kv_rank=1,
-                           kv_parallel_size=2)
+    ktc = KVTransferConfig(
+        kv_connector="PyNcclConnector",
+        kv_role="kv_consumer",
+        kv_rank=1,
+        kv_parallel_size=2,
+    )

     # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
     # memory. You may need to adjust the value to fit your GPU.
-    llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
-              kv_transfer_config=ktc,
-              max_model_len=2000,
-              gpu_memory_utilization=0.8)
+    llm = LLM(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        kv_transfer_config=ktc,
+        max_model_len=2000,
+        gpu_memory_utilization=0.8,
+    )

     # Wait for the producer to start the pipe
     print("Waiting for prefill node to finish...")
@@ -99,8 +108,8 @@ def run_decode(prefill_done):
 def main():
     prefill_done = Event()
-    prefill_process = Process(target=run_prefill, args=(prefill_done, ))
-    decode_process = Process(target=run_decode, args=(prefill_done, ))
+    prefill_process = Process(target=run_prefill, args=(prefill_done,))
+    decode_process = Process(target=run_decode, args=(prefill_done,))

     # Start prefill node
     prefill_process.start()

View File

@@ -20,9 +20,7 @@ def load_prompts(dataset_path, num_prompts):
             print(f"Error reading dataset: {e}")
             return []
     else:
-        prompts = [
-            "The future of AI is", "The president of the United States is"
-        ]
+        prompts = ["The future of AI is", "The president of the United States is"]

     return prompts[:num_prompts]
@@ -33,34 +31,32 @@ def parse_args():
         "--dataset",
         type=str,
         default="./examples/data/gsm8k.jsonl",
-        help="downloaded from the eagle repo " \
-        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+        help="downloaded from the eagle repo "
+        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
+    )
+    parser.add_argument(
+        "--method", type=str, default="eagle", choices=["eagle", "eagle3"]
     )
-    parser.add_argument("--method",
-                        type=str,
-                        default='eagle',
-                        choices=['eagle', 'eagle3'])
     parser.add_argument("--max_num_seqs", type=int, default=8)
     parser.add_argument("--num_prompts", type=int, default=80)
     parser.add_argument("--num_spec_tokens", type=int, default=2)
     parser.add_argument("--tp", type=int, default=1)
     parser.add_argument("--draft_tp", type=int, default=1)
-    parser.add_argument("--enforce_eager", action='store_true')
-    parser.add_argument("--enable_chunked_prefill", action='store_true')
+    parser.add_argument("--enforce_eager", action="store_true")
+    parser.add_argument("--enable_chunked_prefill", action="store_true")
     parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
     parser.add_argument("--temp", type=float, default=0)
     return parser.parse_args()


 def main():
     args = parse_args()

     model_dir = "meta-llama/Llama-3.1-8B-Instruct"

-    if args.method == 'eagle':
+    if args.method == "eagle":
         eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-    elif args.method == 'eagle3':
+    elif args.method == "eagle3":
         eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
     else:
         raise ValueError(f"unknown method: {args.method}")
@@ -72,11 +68,9 @@ def main():
     prompts = load_prompts(args.dataset, args.num_prompts)

     prompt_ids = [
-        tokenizer.apply_chat_template([{
-            "role": "user",
-            "content": prompt
-        }],
-                                      add_generation_prompt=True)
+        tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt}], add_generation_prompt=True
+        )
         for prompt in prompts
     ]
@@ -102,8 +96,7 @@ def main():
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

-    outputs = llm.generate(prompt_token_ids=prompt_ids,
-                           sampling_params=sampling_params)
+    outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)

     # print the generated text
     for output in outputs:
@@ -120,19 +113,22 @@ def main():
     # accepted
     acceptance_counts = [0] * (args.num_spec_tokens + 1)
     for output in outputs:
-        for step, count in enumerate(
-                output.metrics.spec_token_acceptance_counts):
+        for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
             acceptance_counts[step] += count

     print("-" * 50)
-    print(f"mean acceptance length (including bonus tokens): \
-        {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}")
+    print(
+        f"mean acceptance length (including bonus tokens): \
+        {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}"
+    )
     print("-" * 50)

     # print acceptance at each token position
     for i in range(len(acceptance_counts)):
-        print(f"acceptance at token {i}:"
-              f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}")
+        print(
+            f"acceptance at token {i}:"
+            f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}"
+        )


 if __name__ == "__main__":

View File

@@ -10,9 +10,9 @@ def parse_args():
     parser = FlexibleArgumentParser()
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
-    parser.set_defaults(model="jinaai/jina-embeddings-v3",
-                        task="embed",
-                        trust_remote_code=True)
+    parser.set_defaults(
+        model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
+    )
     return parser.parse_args()
@@ -41,11 +41,14 @@ def main(args: Namespace):
     print("-" * 60)
     for prompt, output in zip(prompts, outputs):
         embeds = output.outputs.embedding
-        embeds_trimmed = ((str(embeds[:16])[:-1] +
-                           ", ...]") if len(embeds) > 16 else embeds)
-        print(f"Prompt: {prompt!r} \n"
-              f"Embeddings for text matching: {embeds_trimmed} "
-              f"(size={len(embeds)})")
+        embeds_trimmed = (
+            (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
+        )
+        print(
+            f"Prompt: {prompt!r} \n"
+            f"Embeddings for text matching: {embeds_trimmed} "
+            f"(size={len(embeds)})"
+        )
         print("-" * 60)

View File

@@ -10,9 +10,9 @@ def parse_args():
     parser = FlexibleArgumentParser()
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
-    parser.set_defaults(model="jinaai/jina-embeddings-v3",
-                        task="embed",
-                        trust_remote_code=True)
+    parser.set_defaults(
+        model="jinaai/jina-embeddings-v3", task="embed", trust_remote_code=True
+    )
     return parser.parse_args()
@@ -39,11 +39,10 @@ def main(args: Namespace):
     print("-" * 60)
     for prompt, output in zip(prompts, outputs):
         embeds = output.outputs.embedding
-        embeds_trimmed = ((str(embeds[:16])[:-1] +
-                           ", ...]") if len(embeds) > 16 else embeds)
-        print(f"Prompt: {prompt!r} \n"
-              f"Embeddings: {embeds_trimmed} "
-              f"(size={len(embeds)})")
+        embeds_trimmed = (
+            (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
+        )
+        print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
         print("-" * 60)

View File

@@ -1,12 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
-'''
+"""
 Demonstrate prompting of text-to-text
 encoder/decoder models, specifically BART
-'''
+"""

 from vllm import LLM, SamplingParams
-from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         TokensPrompt, zip_enc_dec_prompts)
+from vllm.inputs import (
+    ExplicitEncoderDecoderPrompt,
+    TextPrompt,
+    TokensPrompt,
+    zip_enc_dec_prompts,
+)


 def create_prompts(tokenizer):
@@ -18,8 +22,9 @@ def create_prompts(tokenizer):
     # - Helpers for building prompts
     text_prompt_raw = "Hello, my name is"
     text_prompt = TextPrompt(prompt="The president of the United States is")
-    tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
-        prompt="The capital of France is"))
+    tokens_prompt = TokensPrompt(
+        prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
+    )
     # - Pass a single prompt to encoder/decoder model
     #   (implicitly encoder input prompt);
     #   decoder input prompt is assumed to be None
@@ -57,14 +62,19 @@ def create_prompts(tokenizer):
     #   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
     #   instances
     zipped_prompt_list = zip_enc_dec_prompts(
-        ['An encoder prompt', 'Another encoder prompt'],
-        ['A decoder prompt', 'Another decoder prompt'])
+        ["An encoder prompt", "Another encoder prompt"],
+        ["A decoder prompt", "Another decoder prompt"],
+    )

     # - Let's put all of the above example prompts together into one list
     #   which we will pass to the encoder/decoder LLM.
     return [
-        single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
-        enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
+        single_text_prompt_raw,
+        single_text_prompt,
+        single_tokens_prompt,
+        enc_dec_prompt1,
+        enc_dec_prompt2,
+        enc_dec_prompt3,
     ] + zipped_prompt_list
@@ -85,10 +95,12 @@ def print_outputs(outputs):
         prompt = output.prompt
         encoder_prompt = output.encoder_prompt
         generated_text = output.outputs[0].text
-        print(f"Output {i+1}:")
-        print(f"Encoder prompt: {encoder_prompt!r}\n"
-              f"Decoder prompt: {prompt!r}\n"
-              f"Generated text: {generated_text!r}")
+        print(f"Output {i + 1}:")
+        print(
+            f"Encoder prompt: {encoder_prompt!r}\n"
+            f"Decoder prompt: {prompt!r}\n"
+            f"Generated text: {generated_text!r}"
+        )
         print("-" * 50)

View File

@@ -3,6 +3,7 @@
 This example shows how to use vLLM for running offline inference with
 the explicit/implicit prompt format on enc-dec LMMs for text generation.
 """
+
 import time
 from collections.abc import Sequence
 from dataclasses import asdict
@@ -30,18 +31,14 @@ def run_florence2():
     )

     prompts = [
         {  # implicit prompt with task token
             "prompt": "<DETAILED_CAPTION>",
-            "multi_modal_data": {
-                "image": ImageAsset("stop_sign").pil_image
-            },
+            "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
         },
         {  # explicit encoder/decoder prompt
             "encoder_prompt": {
                 "prompt": "Describe in detail what is shown in the image.",
-                "multi_modal_data": {
-                    "image": ImageAsset("cherry_blossom").pil_image
-                },
+                "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
             },
             "decoder_prompt": "",
         },
@@ -63,20 +60,20 @@ def run_mllama():
     )

     prompts = [
         {  # Implicit prompt
             "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
             "multi_modal_data": {
                 "image": ImageAsset("stop_sign").pil_image,
             },
         },
         {  # Explicit prompt
             "encoder_prompt": {
                 "prompt": "<|image|>",
                 "multi_modal_data": {
                     "image": ImageAsset("stop_sign").pil_image,
                 },
             },
             "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
         },
     ]
@@ -96,13 +93,13 @@ def run_whisper():
     )

     prompts = [
         {  # Test implicit prompt
             "prompt": "<|startoftranscript|>",
             "multi_modal_data": {
                 "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
             },
         },
         {  # Test explicit encoder/decoder prompt
             "encoder_prompt": {
                 "prompt": "",
                 "multi_modal_data": {
@@ -110,7 +107,7 @@ def run_whisper():
                 },
             },
             "decoder_prompt": "<|startoftranscript|>",
-        }
+        },
     ]

     return ModelRequestData(
@@ -128,18 +125,23 @@ model_example_map = {
 def parse_args():
     parser = FlexibleArgumentParser(
-        description='Demo on using vLLM for offline inference with '
-        'vision language models for text generation')
-    parser.add_argument('--model-type',
-                        '-m',
-                        type=str,
-                        default="mllama",
-                        choices=model_example_map.keys(),
-                        help='Huggingface "model_type".')
-    parser.add_argument("--seed",
-                        type=int,
-                        default=None,
-                        help="Set the seed when initializing `vllm.LLM`.")
+        description="Demo on using vLLM for offline inference with "
+        "vision language models for text generation"
+    )
+    parser.add_argument(
+        "--model-type",
+        "-m",
+        type=str,
+        default="mllama",
+        choices=model_example_map.keys(),
+        help='Huggingface "model_type".',
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="Set the seed when initializing `vllm.LLM`.",
+    )

     return parser.parse_args()
@@ -153,7 +155,8 @@ def main(args):
     # Disable other modalities to save memory
     default_limits = {"image": 0, "video": 0, "audio": 0}
     req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
-        req_data.engine_args.limit_mm_per_prompt or {})
+        req_data.engine_args.limit_mm_per_prompt or {}
+    )

     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
@@ -179,8 +182,7 @@ def main(args):
     for output in outputs:
         prompt = output.prompt
         generated_text = output.outputs[0].text
-        print(f"Decoder prompt: {prompt!r}, "
-              f"Generated text: {generated_text!r}")
+        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

     duration = time.time() - start

View File

@@ -3,6 +3,7 @@
 This file demonstrates using the `LLMEngine`
 for processing prompts with various sampling parameters.
 """
+
 import argparse

 from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
@@ -12,24 +13,26 @@ from vllm.utils import FlexibleArgumentParser
 def create_test_prompts() -> list[tuple[str, SamplingParams]]:
     """Create a list of test prompts with their sampling parameters."""
     return [
-        ("A robot may not injure a human being",
-         SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
-        ("To be or not to be,",
-         SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
-        ("What is the meaning of life?",
-         SamplingParams(n=2,
-                        temperature=0.8,
-                        top_p=0.95,
-                        frequency_penalty=0.1)),
+        (
+            "A robot may not injure a human being",
+            SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
+        ),
+        (
+            "To be or not to be,",
+            SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
+        ),
+        (
+            "What is the meaning of life?",
+            SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
+        ),
     ]


-def process_requests(engine: LLMEngine,
-                     test_prompts: list[tuple[str, SamplingParams]]):
+def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
     """Continuously process a list of prompts and handle the outputs."""
     request_id = 0

-    print('-' * 50)
+    print("-" * 50)
     while test_prompts or engine.has_unfinished_requests():
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
@@ -41,7 +44,7 @@ def process_requests(engine: LLMEngine,
         for request_output in request_outputs:
             if request_output.finished:
                 print(request_output)
-                print('-' * 50)
+                print("-" * 50)


 def initialize_engine(args: argparse.Namespace) -> LLMEngine:
@@ -52,7 +55,8 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine:
 def parse_args():
     parser = FlexibleArgumentParser(
-        description='Demo on using the LLMEngine class directly')
+        description="Demo on using the LLMEngine class directly"
+    )
     parser = EngineArgs.add_cli_args(parser)
     return parser.parse_args()
@@ -64,6 +68,6 @@ def main(args: argparse.Namespace):
     process_requests(engine, test_prompts)


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     main(args)

View File

@@ -36,22 +36,21 @@ def parse_args():
     parser.set_defaults(load_format="sharded_state")

     # Add validation arguments
-    parser.add_argument("--prompt",
-                        type=str,
-                        default="Hello, world!",
-                        help="Prompt for validation")
-    parser.add_argument("--max-tokens",
-                        type=int,
-                        default=100,
-                        help="Maximum number of tokens to generate")
-    parser.add_argument("--temperature",
-                        type=float,
-                        default=0.7,
-                        help="Sampling temperature")
-    parser.add_argument("--top-p",
-                        type=float,
-                        default=1.0,
-                        help="Top-p sampling parameter")
+    parser.add_argument(
+        "--prompt", type=str, default="Hello, world!", help="Prompt for validation"
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temperature", type=float, default=0.7, help="Sampling temperature"
+    )
+    parser.add_argument(
+        "--top-p", type=float, default=1.0, help="Top-p sampling parameter"
+    )

     return parser.parse_args()
@@ -60,8 +59,9 @@ def main():
     args = parse_args()
     engine_args = EngineArgs.from_cli_args(args)

-    print(f"Loading model from {engine_args.model} "
-          f"using format {engine_args.load_format}")
+    print(
+        f"Loading model from {engine_args.model} using format {engine_args.load_format}"
+    )
     print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")

     # Load the model using engine args

View File

@ -17,50 +17,55 @@ from vllm.lora.request import LoRARequest
def create_test_prompts( def create_test_prompts(
lora_path: str lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: ) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
return [ return [
# this is an example of using quantization without LoRA # this is an example of using quantization without LoRA
("My name is", (
SamplingParams(temperature=0.0, "My name is",
logprobs=1, SamplingParams(
prompt_logprobs=1, temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
max_tokens=128), None), ),
None,
),
# the next three examples use quantization with LoRA # the next three examples use quantization with LoRA
("my name is", (
SamplingParams(temperature=0.0, "my name is",
logprobs=1, SamplingParams(
prompt_logprobs=1, temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
max_tokens=128), ),
LoRARequest("lora-test-1", 1, lora_path)), LoRARequest("lora-test-1", 1, lora_path),
("The capital of USA is", ),
SamplingParams(temperature=0.0, (
logprobs=1, "The capital of USA is",
prompt_logprobs=1, SamplingParams(
max_tokens=128), temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
LoRARequest("lora-test-2", 1, lora_path)), ),
("The capital of France is", LoRARequest("lora-test-2", 1, lora_path),
SamplingParams(temperature=0.0, ),
logprobs=1, (
prompt_logprobs=1, "The capital of France is",
max_tokens=128), SamplingParams(
LoRARequest("lora-test-3", 1, lora_path)), temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-3", 1, lora_path),
),
] ]
def process_requests(engine: LLMEngine, def process_requests(
test_prompts: list[tuple[str, SamplingParams, engine: LLMEngine,
Optional[LoRARequest]]]): test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
"""Continuously process a list of prompts and handle the outputs.""" """Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
while test_prompts or engine.has_unfinished_requests(): while test_prompts or engine.has_unfinished_requests():
if test_prompts: if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0) prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id), engine.add_request(
prompt, str(request_id), prompt, sampling_params, lora_request=lora_request
sampling_params, )
lora_request=lora_request)
request_id += 1 request_id += 1
request_outputs: list[RequestOutput] = engine.step() request_outputs: list[RequestOutput] = engine.step()
@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine,
print(f"Output: {request_output.outputs[0].text}") print(f"Output: {request_output.outputs[0].text}")
def initialize_engine(model: str, quantization: str, def initialize_engine(
lora_repo: Optional[str]) -> LLMEngine: model: str, quantization: str, lora_repo: Optional[str]
) -> LLMEngine:
"""Initialize the LLMEngine.""" """Initialize the LLMEngine."""
engine_args = EngineArgs(model=model, engine_args = EngineArgs(
quantization=quantization, model=model,
enable_lora=True, quantization=quantization,
max_lora_rank=64, enable_lora=True,
max_loras=4) max_lora_rank=64,
max_loras=4,
)
return LLMEngine.from_engine_args(engine_args) return LLMEngine.from_engine_args(engine_args)
@ -90,32 +98,30 @@ def main():
# QLoRA (https://arxiv.org/abs/2305.14314) # QLoRA (https://arxiv.org/abs/2305.14314)
{ {
"name": "qlora_inference_example", "name": "qlora_inference_example",
'model': "huggyllama/llama-7b", "model": "huggyllama/llama-7b",
'quantization': "bitsandbytes", "quantization": "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b' "lora_repo": "timdettmers/qlora-flan-7b",
}, },
{ {
"name": "AWQ_inference_with_lora_example", "name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
'quantization': "awq", "quantization": "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora' "lora_repo": "jashing/tinyllama-colorist-lora",
}, },
{ {
"name": "GPTQ_inference_with_lora_example", "name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
'quantization': "gptq", "quantization": "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora' "lora_repo": "jashing/tinyllama-colorist-lora",
} },
] ]
for test_config in test_configs: for test_config in test_configs:
print( print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~" engine = initialize_engine(
test_config["model"], test_config["quantization"], test_config["lora_repo"]
) )
engine = initialize_engine(test_config['model'], lora_path = snapshot_download(repo_id=test_config["lora_repo"])
test_config['quantization'],
test_config['lora_repo'])
lora_path = snapshot_download(repo_id=test_config['lora_repo'])
test_prompts = create_test_prompts(lora_path) test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts) process_requests(engine, test_prompts)
@ -125,5 +131,5 @@ def main():
torch.cuda.empty_cache() torch.cuda.empty_cache()
if __name__ == '__main__': if __name__ == "__main__":
main() main()
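For readers who prefer the high-level API, here is a minimal sketch (not part of this commit) that runs the first test configuration above through `vllm.LLM` instead of `LLMEngine`; the model and adapter repos are the ones listed in `test_configs`, everything else is illustrative:

```python
# Hedged sketch: the same QLoRA setup as the "qlora_inference_example" config
# above, driven through the high-level LLM API rather than LLMEngine.
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="timdettmers/qlora-flan-7b")
llm = LLM(
    model="huggyllama/llama-7b",
    quantization="bitsandbytes",
    enable_lora=True,
    max_lora_rank=64,
)
outputs = llm.generate(
    "my name is",
    SamplingParams(temperature=0.0, max_tokens=32),
    lora_request=LoRARequest("qlora-adapter", 1, lora_path),
)
print(outputs[0].outputs[0].text)
```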
View File
@ -74,19 +74,10 @@ def run_simple_demo(args: argparse.Namespace):
messages = [ messages = [
{ {
"role": "role": "user",
"user",
"content": [ "content": [
{ {"type": "text", "text": prompt},
"type": "text", {"type": "image_url", "image_url": {"url": image_url}},
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
], ],
}, },
] ]
@ -121,25 +112,11 @@ def run_advanced_demo(args: argparse.Namespace):
messages = [ messages = [
{ {
"role": "role": "user",
"user",
"content": [ "content": [
{ {"type": "text", "text": prompt},
"type": "text", {"type": "image_url", "image_url": {"url": url_1}},
"text": prompt {"type": "image_url", "image_url": {"url": url_2}},
},
{
"type": "image_url",
"image_url": {
"url": url_1
}
},
{
"type": "image_url",
"image_url": {
"url": url_2
}
},
], ],
}, },
{ {
@ -153,12 +130,7 @@ def run_advanced_demo(args: argparse.Namespace):
{ {
"role": "user", "role": "user",
"content": [ "content": [
{ {"type": "image_url", "image_url": {"url": url_3}},
"type": "image_url",
"image_url": {
"url": url_3
}
},
], ],
}, },
] ]
@ -171,7 +143,8 @@ def run_advanced_demo(args: argparse.Namespace):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run a demo in simple or advanced mode.") description="Run a demo in simple or advanced mode."
)
parser.add_argument( parser.add_argument(
"mode", "mode",
@ -179,15 +152,18 @@ def parse_args():
help="Specify the demo mode: 'simple' or 'advanced'", help="Specify the demo mode: 'simple' or 'advanced'",
) )
parser.add_argument('--format', parser.add_argument(
choices=["mistral", "hf"], "--format",
default="mistral", choices=["mistral", "hf"],
help='Specify the format of the model to load.') default="mistral",
help="Specify the format of the model to load.",
)
parser.add_argument( parser.add_argument(
'--disable-mm-preprocessor-cache', "--disable-mm-preprocessor-cache",
action='store_true', action="store_true",
help='If True, disables caching of multi-modal preprocessor/mapper.') help="If True, disables caching of multi-modal preprocessor/mapper.",
)
return parser.parse_args() return parser.parse_args()
View File
@ -13,8 +13,9 @@ import time
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
def time_generation(llm: LLM, prompts: list[str], def time_generation(
sampling_params: SamplingParams, title: str): llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
):
# Generate texts from the prompts. The output is a list of RequestOutput # Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information. # objects that contain the prompt, generated text, and other information.
# Warmup first # Warmup first
@ -25,8 +26,7 @@ def time_generation(llm: LLM, prompts: list[str],
end = time.time() end = time.time()
print("-" * 50) print("-" * 50)
print(title) print(title)
print("time: ", print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
(end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
# Print the outputs. # Print the outputs.
for output in outputs: for output in outputs:
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
@ -38,7 +38,8 @@ def main():
template = ( template = (
"Below is an instruction that describes a task. Write a response " "Below is an instruction that describes a task. Write a response "
"that appropriately completes the request.\n\n### Instruction:\n{}" "that appropriately completes the request.\n\n### Instruction:\n{}"
"\n\n### Response:\n") "\n\n### Response:\n"
)
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
View File
@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest
def create_test_prompts( def create_test_prompts(
lora_path: str lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: ) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters. """Create a list of test prompts with their sampling parameters.
@ -26,38 +26,49 @@ def create_test_prompts(
first adapter have finished. first adapter have finished.
""" """
return [ return [
("A robot may not injure a human being", (
SamplingParams(temperature=0.0, "A robot may not injure a human being",
logprobs=1, SamplingParams(
prompt_logprobs=1, temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
max_tokens=128), None), ),
("To be or not to be,", None,
SamplingParams(temperature=0.8, ),
top_k=5, (
presence_penalty=0.2, "To be or not to be,",
max_tokens=128), None), SamplingParams(
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
),
None,
),
( (
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0, SamplingParams(
logprobs=1, temperature=0.0,
prompt_logprobs=1, logprobs=1,
max_tokens=128, prompt_logprobs=1,
stop_token_ids=[32003]), max_tokens=128,
LoRARequest("sql-lora", 1, lora_path)), stop_token_ids=[32003],
),
LoRARequest("sql-lora", 1, lora_path),
),
( (
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0, SamplingParams(
logprobs=1, temperature=0.0,
prompt_logprobs=1, logprobs=1,
max_tokens=128, prompt_logprobs=1,
stop_token_ids=[32003]), max_tokens=128,
LoRARequest("sql-lora2", 2, lora_path)), stop_token_ids=[32003],
),
LoRARequest("sql-lora2", 2, lora_path),
),
] ]
def process_requests(engine: LLMEngine, def process_requests(
test_prompts: list[tuple[str, SamplingParams, engine: LLMEngine,
Optional[LoRARequest]]]): test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
"""Continuously process a list of prompts and handle the outputs.""" """Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
@ -65,10 +76,9 @@ def process_requests(engine: LLMEngine,
while test_prompts or engine.has_unfinished_requests(): while test_prompts or engine.has_unfinished_requests():
if test_prompts: if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0) prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id), engine.add_request(
prompt, str(request_id), prompt, sampling_params, lora_request=lora_request
sampling_params, )
lora_request=lora_request)
request_id += 1 request_id += 1
request_outputs: list[RequestOutput] = engine.step() request_outputs: list[RequestOutput] = engine.step()
@ -88,12 +98,14 @@ def initialize_engine() -> LLMEngine:
# numbers will cause higher memory usage. If you know that all LoRAs will # numbers will cause higher memory usage. If you know that all LoRAs will
# use the same rank, it is recommended to set this as low as possible. # use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache. # max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", engine_args = EngineArgs(
enable_lora=True, model="meta-llama/Llama-2-7b-hf",
max_loras=1, enable_lora=True,
max_lora_rank=8, max_loras=1,
max_cpu_loras=2, max_lora_rank=8,
max_num_seqs=256) max_cpu_loras=2,
max_num_seqs=256,
)
return LLMEngine.from_engine_args(engine_args) return LLMEngine.from_engine_args(engine_args)
@ -105,5 +117,5 @@ def main():
process_requests(engine, test_prompts) process_requests(engine, test_prompts)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
View File
@ -30,7 +30,8 @@ def main():
# The device argument can be either unspecified for automated detection, # The device argument can be either unspecified for automated detection,
# or explicitly assigned. # or explicitly assigned.
device="neuron", device="neuron",
tensor_parallel_size=2) tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
View File
@ -24,7 +24,7 @@ llm = LLM(
speculative_config={ speculative_config={
"model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"max_model_len": 2048 "max_model_len": 2048,
}, },
max_num_seqs=4, max_num_seqs=4,
# The max_model_len and block_size arguments are required to be same as # The max_model_len and block_size arguments are required to be same as
@ -40,7 +40,7 @@ llm = LLM(
tensor_parallel_size=32, tensor_parallel_size=32,
override_neuron_config={ override_neuron_config={
"enable_eagle_speculation": True, "enable_eagle_speculation": True,
"enable_fused_speculation": True "enable_fused_speculation": True,
}, },
) )
View File
@ -5,12 +5,12 @@ import os
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# creates XLA hlo graphs for all the context length buckets. # creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets. # creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
# Quantizes neuron model weight to int8 , # Quantizes neuron model weight to int8 ,
# The default config for quantization is int8 dtype. # The default config for quantization is int8 dtype.
os.environ['NEURON_QUANT_DTYPE'] = "s8" os.environ["NEURON_QUANT_DTYPE"] = "s8"
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
@ -44,7 +44,8 @@ def main():
override_neuron_config={ override_neuron_config={
"cast_logits_dtype": "bfloat16", "cast_logits_dtype": "bfloat16",
}, },
tensor_parallel_size=2) tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
View File
@ -19,9 +19,9 @@ prompts = [
def config_buckets(): def config_buckets():
"""Configure context length and token gen buckets.""" """Configure context length and token gen buckets."""
# creates XLA hlo graphs for all the context length buckets. # creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets. # creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
def initialize_model(): def initialize_model():
@ -31,7 +31,7 @@ def initialize_model():
speculative_config={ speculative_config={
"model": "openlm-research/open_llama_3b", "model": "openlm-research/open_llama_3b",
"num_speculative_tokens": 4, "num_speculative_tokens": 4,
"max_model_len": 2048 "max_model_len": 2048,
}, },
max_num_seqs=4, max_num_seqs=4,
max_model_len=2048, max_model_len=2048,
@ -60,5 +60,5 @@ def main():
process_requests(model, sampling_params) process_requests(model, sampling_params)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
View File
@ -16,7 +16,8 @@ prefix = (
"teaching role. They have 5 years of previous teaching experience " "teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience " "as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill " "in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ") "the following paragraph: "
)
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
@ -58,9 +59,11 @@ def main():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
# Create an LLM with prefix caching enabled. # Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(model="facebook/opt-125m", prefix_cached_llm = LLM(
enable_prefix_caching=True, model="facebook/opt-125m",
gpu_memory_utilization=0.4) enable_prefix_caching=True,
gpu_memory_utilization=0.4,
)
# Warmup so that the shared prompt's KV cache is computed. # Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params) prefix_cached_llm.generate(generating_prompts[0], sampling_params)
@ -81,10 +84,12 @@ def main():
print("-" * 50) print("-" * 50)
# Compare the results and display the speedup # Compare the results and display the speedup
generated_same = all([ generated_same = all(
regular_generated_texts[i] == cached_generated_texts[i] [
for i in range(len(prompts)) regular_generated_texts[i] == cached_generated_texts[i]
]) for i in range(len(prompts))
]
)
print(f"Generated answers are the same: {generated_same}") print(f"Generated answers are the same: {generated_same}")
View File
@ -16,7 +16,8 @@ The requirements for running this script are:
Run the example: Run the example:
python prithvi_geospatial_mae.py python prithvi_geospatial_mae.py
""" # noqa: E501 """ # noqa: E501
import argparse import argparse
import datetime import datetime
import os import os
@ -110,77 +111,67 @@ model_config = """{
# Temporarily creating the "config.json" for the model. # Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF # This is going to disappear once the correct config.json is available on HF
with open(os.path.join(os.path.dirname(__file__), "./model/config.json"), with open(
'w') as config_file: os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
) as config_file:
config_file.write(model_config) config_file.write(model_config)
datamodule_config = { datamodule_config = {
'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'], "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
'batch_size': "batch_size": 16,
16, "constant_scale": 0.0001,
'constant_scale': "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
0.0001, "drop_last": True,
'data_root': "no_data_replace": 0.0,
'/dccstor/geofm-finetuning/datasets/sen1floods11', "no_label_replace": -1,
'drop_last': "num_workers": 8,
True, "test_transform": [
'no_data_replace': albumentations.Resize(
0.0, always_apply=False, height=448, interpolation=1, p=1, width=448
'no_label_replace': ),
-1, albumentations.pytorch.ToTensorV2(
'num_workers': transpose_mask=False, always_apply=True, p=1.0
8, ),
'test_transform': [
albumentations.Resize(always_apply=False,
height=448,
interpolation=1,
p=1,
width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False,
always_apply=True,
p=1.0)
], ],
} }
class PrithviMAE: class PrithviMAE:
def __init__(self): def __init__(self):
print("Initializing PrithviMAE model") print("Initializing PrithviMAE model")
self.model = LLM(model=os.path.join(os.path.dirname(__file__), self.model = LLM(
"./model"), model=os.path.join(os.path.dirname(__file__), "./model"),
skip_tokenizer_init=True, skip_tokenizer_init=True,
dtype="float32") dtype="float32",
)
def run(self, input_data, location_coords): def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############") print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure # merge the inputs into one data structure
mm_data = { mm_data = {
"pixel_values": "pixel_values": torch.empty(0) if input_data is None else input_data,
torch.empty(0) if input_data is None else input_data, "location_coords": torch.empty(0)
"location_coords": if location_coords is None
torch.empty(0) if location_coords is None else location_coords else location_coords,
} }
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False) outputs = self.model.encode(prompt, use_tqdm=False)
print( print("################ Inference done (it took seconds) ##############")
"################ Inference done (it took seconds) ##############"
)
return outputs[0].outputs.data return outputs[0].outputs.data
def generate_datamodule(): def generate_datamodule():
datamodule = Sen1Floods11NonGeoDataModule( datamodule = Sen1Floods11NonGeoDataModule(
data_root=datamodule_config['data_root'], data_root=datamodule_config["data_root"],
batch_size=datamodule_config["batch_size"], batch_size=datamodule_config["batch_size"],
num_workers=datamodule_config["num_workers"], num_workers=datamodule_config["num_workers"],
bands=datamodule_config["bands"], bands=datamodule_config["bands"],
drop_last=datamodule_config["drop_last"], drop_last=datamodule_config["drop_last"],
test_transform=datamodule_config["test_transform" test_transform=datamodule_config["test_transform"],
""]) )
return datamodule return datamodule
@ -204,8 +195,7 @@ def process_channel_group(orig_img, channels):
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE)) max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
min_value = OFFSET min_value = OFFSET
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
1)
# No data as zeros # No data as zeros
orig_img[~valid_mask] = 0 orig_img[~valid_mask] = 0
@ -300,18 +290,21 @@ def load_example(
location_coords.append(coords) location_coords.append(coords)
try: try:
match = re.search(r'(\d{7,8}T\d{6})', file) match = re.search(r"(\d{7,8}T\d{6})", file)
if match: if match:
year = int(match.group(1)[:4]) year = int(match.group(1)[:4])
julian_day = match.group(1).split('T')[0][4:] julian_day = match.group(1).split("T")[0][4:]
if len(julian_day) == 3: if len(julian_day) == 3:
julian_day = int(julian_day) julian_day = int(julian_day)
else: else:
julian_day = datetime.datetime.strptime( julian_day = (
julian_day, '%m%d').timetuple().tm_yday datetime.datetime.strptime(julian_day, "%m%d")
.timetuple()
.tm_yday
)
temporal_coords.append([year, julian_day]) temporal_coords.append([year, julian_day])
except Exception as e: except Exception as e:
print(f'Could not extract timestamp for {file} ({e})') print(f"Could not extract timestamp for {file} ({e})")
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32") imgs = np.moveaxis(imgs, -1, 0).astype("float32")
@ -320,50 +313,44 @@ def load_example(
return imgs, temporal_coords, location_coords, metas return imgs, temporal_coords, location_coords, metas
def run_model(input_data, def run_model(
temporal_coords, input_data,
location_coords, temporal_coords,
model, location_coords,
datamodule, model,
img_size, datamodule,
lightning_model=None): img_size,
lightning_model=None,
):
# Reflect pad if not divisible by img_size # Reflect pad if not divisible by img_size
original_h, original_w = input_data.shape[-2:] original_h, original_w = input_data.shape[-2:]
pad_h = (img_size - (original_h % img_size)) % img_size pad_h = (img_size - (original_h % img_size)) % img_size
pad_w = (img_size - (original_w % img_size)) % img_size pad_w = (img_size - (original_w % img_size)) % img_size
input_data = np.pad(input_data, input_data = np.pad(
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
mode="reflect") )
# Build sliding window # Build sliding window
batch_size = 1 batch_size = 1
batch = torch.tensor(input_data, device="cpu") batch = torch.tensor(input_data, device="cpu")
windows = (batch.unfold(3, img_size, windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
img_size).unfold(4, img_size, img_size))
h1, w1 = windows.shape[3:5] h1, w1 = windows.shape[3:5]
windows = rearrange(windows, windows = rearrange(
"b c t h1 w1 h w -> (b h1 w1) c t h w", windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
h=img_size, )
w=img_size)
# Split into batches if number of windows > batch_size # Split into batches if number of windows > batch_size
num_batches = windows.shape[0] // batch_size if windows.shape[ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0) windows = torch.tensor_split(windows, num_batches, dim=0)
if torch.cuda.is_available(): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device('cuda')
else:
device = torch.device('cpu')
if temporal_coords: if temporal_coords:
temporal_coords = torch.tensor(temporal_coords, temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
device=device).unsqueeze(0)
else: else:
temporal_coords = None temporal_coords = None
if location_coords: if location_coords:
location_coords = torch.tensor(location_coords[0], location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
device=device).unsqueeze(0)
else: else:
location_coords = None location_coords = None
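As an aside, the reflect-pad and `unfold` bookkeeping above is easier to follow with concrete shapes; the following self-contained sketch (dummy sizes, `img_size=224` chosen only for illustration) reproduces just that windowing step:

```python
# Illustrative shapes only: pad a (B, C, T, H, W) input so H and W become
# multiples of img_size, then cut it into img_size x img_size windows.
import numpy as np
import torch
from einops import rearrange

img_size = 224
input_data = np.zeros((1, 6, 1, 500, 300), dtype="float32")  # dummy raster

original_h, original_w = input_data.shape[-2:]
pad_h = (img_size - (original_h % img_size)) % img_size  # 172 -> padded H = 672
pad_w = (img_size - (original_w % img_size)) % img_size  # 148 -> padded W = 448
input_data = np.pad(
    input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
)

batch = torch.tensor(input_data, device="cpu")
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]
windows = rearrange(
    windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
)
print(h1, w1, windows.shape)  # 3 2 torch.Size([6, 6, 1, 224, 224])
```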
@ -371,26 +358,24 @@ def run_model(input_data,
pred_imgs = [] pred_imgs = []
for x in windows: for x in windows:
# Apply standardization # Apply standardization
x = datamodule.test_transform( x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
image=x.squeeze().numpy().transpose(1, 2, 0)) x = datamodule.aug(x)["image"]
x = datamodule.aug(x)['image']
with torch.no_grad(): with torch.no_grad():
x = x.to(device) x = x.to(device)
pred = model.run(x, location_coords=location_coords) pred = model.run(x, location_coords=location_coords)
if lightning_model: if lightning_model:
pred_lightning = lightning_model( pred_lightning = lightning_model(
x, x, temporal_coords=temporal_coords, location_coords=location_coords
temporal_coords=temporal_coords, )
location_coords=location_coords)
pred_lightning = pred_lightning.output.detach().cpu() pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning): if not torch.equal(pred, pred_lightning):
print("Inference output is not equal") print("Inference output is not equal")
y_hat = pred.argmax(dim=1) y_hat = pred.argmax(dim=1)
y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), y_hat = torch.nn.functional.interpolate(
size=img_size, y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
mode="nearest") )
pred_imgs.append(y_hat) pred_imgs.append(y_hat)
@ -437,8 +422,7 @@ def parse_args():
default=[1, 2, 3, 8, 11, 12], default=[1, 2, 3, 8, 11, 12],
type=int, type=int,
nargs="+", nargs="+",
help= help="0-based indices of the six Prithvi channels to be selected from the "
"0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.", "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
) )
parser.add_argument( parser.add_argument(
@ -478,17 +462,18 @@ def main(
# Running model ------------------------------------------------------------ # Running model ------------------------------------------------------------
channels = [ channels = [
datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"] datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB ] # BGR -> RGB
pred = run_model(input_data, temporal_coords, location_coords, model_obj, pred = run_model(
datamodule, img_size) input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
)
# Save pred # Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join( pred_file = os.path.join(
output_dir, output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") )
save_geotiff(_convert_np_uint8(pred), pred_file, meta_data) save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)
# Save image + pred # Save image + pred
@ -502,13 +487,13 @@ def main(
channels=channels, channels=channels,
) )
pred[pred == 0.] = np.nan pred[pred == 0.0] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3 img_pred = rgb_orig * 0.7 + pred * 0.3
img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()] img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
img_pred_file = os.path.join( img_pred_file = os.path.join(
output_dir, output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") )
save_geotiff( save_geotiff(
image=_convert_np_uint8(img_pred), image=_convert_np_uint8(img_pred),
output_path=img_pred_file, output_path=img_pred_file,
@ -518,8 +503,9 @@ def main(
# Save image rgb # Save image rgb
if rgb_outputs: if rgb_outputs:
rgb_file = os.path.join( rgb_file = os.path.join(
output_dir, "original_rgb_" output_dir,
f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff") f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
)
save_geotiff( save_geotiff(
image=_convert_np_uint8(rgb_orig), image=_convert_np_uint8(rgb_orig),
output_path=rgb_file, output_path=rgb_file,
@ -528,7 +514,6 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
main(**vars(args)) main(**vars(args))
View File
@ -44,8 +44,11 @@ def get_dtype(dtype: str):
OutputLen_NumReqs_Map: TypeAlias = dict[int, int] OutputLen_NumReqs_Map: TypeAlias = dict[int, int]
def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
-> OutputLen_NumReqs_Map:
def compute_request_output_lengths(
batch_size: int, step_requests: list[int]
) -> OutputLen_NumReqs_Map:
""" """
Given the number of requests, batch_size, and the number of requests Given the number of requests, batch_size, and the number of requests
that each engine-step should process, step_requests, determine the that each engine-step should process, step_requests, determine the
@ -100,17 +103,19 @@ def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
output_length -= 1 output_length -= 1
# sanity checks. # sanity checks.
assert sum(ol_nr.values()) == batch_size, \ assert sum(ol_nr.values()) == batch_size, (
("Number of requests in output-length assignment does not match " "Number of requests in output-length assignment does not match "
f"batch-size.\n batch size {batch_size} - " f"batch-size.\n batch size {batch_size} - "
f"step requests {step_requests} - assignments {ol_nr}") f"step requests {step_requests} - assignments {ol_nr}"
)
# Check that the output-length is in [1, num-steps]. Output length must be # Check that the output-length is in [1, num-steps]. Output length must be
# at least 1 as all requests must participate in the prefill-step. # at least 1 as all requests must participate in the prefill-step.
assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \ assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), (
("Output lengths of requests should be in range " "Output lengths of requests should be in range "
f"[1, num-engine-steps].\n batch size {batch_size} - " f"[1, num-engine-steps].\n batch size {batch_size} - "
f"step requests {step_requests} - assignments {ol_nr}") f"step requests {step_requests} - assignments {ol_nr}"
)
return ol_nr return ol_nr
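The asserts above pin down the contract of `compute_request_output_lengths` even though the full derivation falls outside this hunk. A small self-contained sketch of one way to satisfy that contract (the numbers are chosen to match the `run_to_completion` help text further down; this is not the actual implementation):

```python
# step_requests[i] = number of requests still running at engine step i, so the
# number of requests that finish with output length k is
# step_requests[k - 1] - step_requests[k].
batch_size = 128
step_requests = [128, 128, 96, 64, 32, 1]

ol_nr: dict[int, int] = {}
for k in range(1, len(step_requests) + 1):
    running_now = step_requests[k - 1]
    running_next = step_requests[k] if k < len(step_requests) else 0
    if running_now > running_next:
        ol_nr[k] = running_now - running_next

assert sum(ol_nr.values()) == batch_size
assert all(1 <= ol <= len(step_requests) for ol in ol_nr)
print(ol_nr)  # {2: 32, 3: 32, 4: 32, 5: 31, 6: 1}
```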
@ -140,10 +145,13 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
# that their output lengths must be equal to num_engine_steps. # that their output lengths must be equal to num_engine_steps.
return [context.batch_size] * context.num_steps return [context.batch_size] * context.num_steps
assert context.complete_num_requests_per_step and \ assert (
context.complete_num_requests_per_step > 0, \ context.complete_num_requests_per_step
(f"Expected a positive complete_num_requests_per_step argument." and context.complete_num_requests_per_step > 0
f"Instead got {context.complete_num_requests_per_step}") ), (
f"Expected a positive complete_num_requests_per_step argument."
f"Instead got {context.complete_num_requests_per_step}"
)
# We start dropping after the first decode step. # We start dropping after the first decode step.
step_requests = [ step_requests = [
@ -165,8 +173,9 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
return step_requests return step_requests
def run_profile(context: ProfileContext, csv_output: Optional[str], def run_profile(
json_output: Optional[str]): context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]
):
print("Run profile with:") print("Run profile with:")
for key, value in asdict(context).items(): for key, value in asdict(context).items():
print(f" {key} = {value}") print(f" {key} = {value}")
@ -174,7 +183,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
requests_per_step: list[int] = determine_requests_per_step(context) requests_per_step: list[int] = determine_requests_per_step(context)
ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
context.batch_size, requests_per_step) context.batch_size, requests_per_step
)
num_steps_to_profile: int = len(requests_per_step) num_steps_to_profile: int = len(requests_per_step)
max_output_len: int = max(ol_nr.keys()) max_output_len: int = max(ol_nr.keys())
@ -186,7 +196,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
top_p=0.95, top_p=0.95,
# max_tokens is set on a per-request basis. # max_tokens is set on a per-request basis.
max_tokens=None, max_tokens=None,
ignore_eos=True) ignore_eos=True,
)
# Create LLM # Create LLM
llm = LLM(**asdict(context.engine_args)) llm = LLM(**asdict(context.engine_args))
@ -199,31 +210,37 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
max_num_seqs = scheduler_config.max_num_seqs max_num_seqs = scheduler_config.max_num_seqs
if batch_size * prompt_len > max_num_batched_tokens: if batch_size * prompt_len > max_num_batched_tokens:
print(f"ERROR: chosen batch_size * prompt_len " print(
f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " f"ERROR: chosen batch_size * prompt_len "
f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
f"and therefore cannot be run in a single profile step, please " f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
f"choose a smaller batch size or prompt length, or increase " f"and therefore cannot be run in a single profile step, please "
f"--max-num-batched-tokens") f"choose a smaller batch size or prompt length, or increase "
f"--max-num-batched-tokens"
)
sys.exit(-1) sys.exit(-1)
if batch_size > max_num_seqs: if batch_size > max_num_seqs:
print( print(
f"ERROR: chosen batch_size ({batch_size}) is larger than " f"ERROR: chosen batch_size ({batch_size}) is larger than "
f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
f"single profile step, please choose a smaller batch size") f"single profile step, please choose a smaller batch size"
)
sys.exit(-1) sys.exit(-1)
print("llm.llm_engine.model_config.max_model_len: ", print(
llm.llm_engine.model_config.max_model_len) "llm.llm_engine.model_config.max_model_len: ",
llm.llm_engine.model_config.max_model_len,
)
if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " print(
f"{max_output_len} = {prompt_len + max_output_len}) is larger " f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
f"than the model's max_model_len ({max_model_len}), please " f"{max_output_len} = {prompt_len + max_output_len}) is larger "
f"choose a smaller prompt_len or max_output_len, or increase " f"than the model's max_model_len ({max_model_len}), please "
f"--max-model-len") f"choose a smaller prompt_len or max_output_len, or increase "
f"--max-model-len"
)
sys.exit(-1) sys.exit(-1)
def add_requests(): def add_requests():
def get_output_len_generator() -> Generator[int, Any, Any]: def get_output_len_generator() -> Generator[int, Any, Any]:
for output_len, num_reqs in ol_nr.items(): for output_len, num_reqs in ol_nr.items():
for _ in range(num_reqs): for _ in range(num_reqs):
@ -234,13 +251,15 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
sampling_params.max_tokens = next(output_len_generator) sampling_params.max_tokens = next(output_len_generator)
assert isinstance(sampling_params.max_tokens, int) assert isinstance(sampling_params.max_tokens, int)
prompt_token_ids = torch.randint(llm.get_tokenizer().vocab_size, prompt_token_ids = torch.randint(
size=(prompt_len, )).tolist() llm.get_tokenizer().vocab_size, size=(prompt_len,)
).tolist()
llm.llm_engine.add_request( llm.llm_engine.add_request(
request_id=f"seq{i}", request_id=f"seq{i}",
prompt={'prompt_token_ids': prompt_token_ids}, prompt={"prompt_token_ids": prompt_token_ids},
params=sampling_params) params=sampling_params,
)
def abort_requests(): def abort_requests():
for i in range(batch_size): for i in range(batch_size):
@ -261,10 +280,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
decode_profs = [] decode_profs = []
for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
num_running_seqs = llm.llm_engine.scheduler[ num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups()
0].get_num_unfinished_seq_groups() with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof:
with layerwise_profile(
num_running_seqs=num_running_seqs) as decode_prof:
llm.llm_engine.step() llm.llm_engine.step()
decode_profs.append(decode_prof) decode_profs.append(decode_prof)
@ -274,8 +291,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
LINE_WIDTH = 80 LINE_WIDTH = 80
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= Prefill Model Table " print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})")
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
prefill_results.print_model_table() prefill_results.print_model_table()
@ -283,16 +299,17 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
if has_decode: if has_decode:
print() print()
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= First Decode Step Model Table " print(
f"(prompt_len={prompt_len}, batch_size={batch_size})") f"= First Decode Step Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})"
)
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
decode_results_list[0].print_model_table() decode_results_list[0].print_model_table()
print() print()
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= Prefill Summary Table " print(f"= Prefill Summary Table (prompt_len={prompt_len}, batch_size={batch_size})")
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
prefill_results.print_summary_table() prefill_results.print_summary_table()
@ -300,25 +317,32 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
if has_decode: if has_decode:
print() print()
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= First Decode Step Summary Table " print(
f"(prompt_len={prompt_len}, batch_size={batch_size})") f"= First Decode Step Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})"
)
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
decode_results_list[0].print_summary_table() decode_results_list[0].print_summary_table()
if csv_output: if csv_output:
csv_filename_base = csv_output[:-4] \ csv_filename_base = (
if csv_output.endswith('.csv') else csv_output csv_output[:-4] if csv_output.endswith(".csv") else csv_output
)
prefill_results.export_model_stats_table_csv( prefill_results.export_model_stats_table_csv(
csv_filename_base + "_prefill_model_table.csv") csv_filename_base + "_prefill_model_table.csv"
)
prefill_results.export_summary_stats_table_csv( prefill_results.export_summary_stats_table_csv(
csv_filename_base + "_prefill_summary_table.csv") csv_filename_base + "_prefill_summary_table.csv"
)
if has_decode: if has_decode:
decode_results_list[0].export_model_stats_table_csv(\ decode_results_list[0].export_model_stats_table_csv(
csv_filename_base + "_decode_model_table.csv") csv_filename_base + "_decode_model_table.csv"
)
decode_results_list[0].export_summary_stats_table_csv( decode_results_list[0].export_summary_stats_table_csv(
csv_filename_base + "_decode_summary_table.csv") csv_filename_base + "_decode_summary_table.csv"
)
if json_output: if json_output:
cuda_devices = [ cuda_devices = [
@ -332,7 +356,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
"torch_version": f"{torch.__version__}", "torch_version": f"{torch.__version__}",
"torch_cuda_version": f"{torch.version.cuda}", "torch_cuda_version": f"{torch.version.cuda}",
"cuda_devices": f"{cuda_devices}", "cuda_devices": f"{cuda_devices}",
**asdict(context) **asdict(context),
}, },
"prefill": prefill_results.convert_stats_to_dict(), "prefill": prefill_results.convert_stats_to_dict(),
} }
@ -342,8 +366,9 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
# Add .json to json_output filename if it doesn't exist already. # Add .json to json_output filename if it doesn't exist already.
json_output_file = json_output if json_output.endswith( json_output_file = (
'.json') else json_output + '.json' json_output if json_output.endswith(".json") else json_output + ".json"
)
with open(json_output_file, "w+") as f: with open(json_output_file, "w+") as f:
json.dump(json_dict, f, indent=2) json.dump(json_dict, f, indent=2)
pass pass
@ -351,16 +376,21 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
if context.save_chrome_traces_folder is not None: if context.save_chrome_traces_folder is not None:
os.makedirs(context.save_chrome_traces_folder, exist_ok=True) os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
prefill_prof.profiler.export_chrome_trace( prefill_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + "/prefill.json") context.save_chrome_traces_folder + "/prefill.json"
)
for idx, decode_prof in enumerate(decode_profs): for idx, decode_prof in enumerate(decode_profs):
decode_prof.profiler.export_chrome_trace( decode_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + f"/decode_{idx + 1}.json") context.save_chrome_traces_folder + f"/decode_{idx + 1}.json"
print("Traces saved as prefill.json and decode_1.json, etc." )
f" in folder {context.save_chrome_traces_folder}") print(
"Traces saved as prefill.json and decode_1.json, etc."
f" in folder {context.save_chrome_traces_folder}"
)
def parse_args(): def parse_args():
parser = FlexibleArgumentParser(description=""" parser = FlexibleArgumentParser(
description="""
Profile a model Profile a model
example: example:
@ -384,7 +414,8 @@ Profile a model
--output-directory profile_breakdown --plot-metric pct_cuda_time --output-directory profile_breakdown --plot-metric pct_cuda_time
``` ```
""", """,
formatter_class=RawTextHelpFormatter) formatter_class=RawTextHelpFormatter,
)
parser.add_argument( parser.add_argument(
"--csv", "--csv",
type=str, type=str,
@ -393,59 +424,68 @@ Profile a model
"filename, will create <filename>_prefill_model_table.csv, " "filename, will create <filename>_prefill_model_table.csv, "
"<filename>_prefill_summary_table.csv, " "<filename>_prefill_summary_table.csv, "
"<filename>_decode_model_table.csv, and " "<filename>_decode_model_table.csv, and "
"<filename>_decode_summary_table.csv") "<filename>_decode_summary_table.csv",
)
parser.add_argument( parser.add_argument(
"--json", "--json",
type=str, type=str,
default=None, default=None,
help="Export the results as a json file. This should be the filename") help="Export the results as a json file. This should be the filename",
parser.add_argument("--save-chrome-traces-folder", )
type=str, parser.add_argument(
help="Save chrome traces for the prefill and decode " "--save-chrome-traces-folder",
"will save traces as prefill.json and decode_1.json, " type=str,
"etc. inside this folder") help="Save chrome traces for the prefill and decode "
"will save traces as prefill.json and decode_1.json, "
"etc. inside this folder",
)
parser.add_argument( parser.add_argument(
"--prompt-len", "--prompt-len",
type=int, type=int,
default=PROMPT_LEN_DEFAULT, default=PROMPT_LEN_DEFAULT,
help=f"Length of the random prompt to use when profiling, all batched " help=f"Length of the random prompt to use when profiling, all batched "
f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}",
parser.add_argument("--batch-size", )
type=int, parser.add_argument(
default=BATCH_SIZE_DEFAULT, "--batch-size",
help=f"Number of requests to run as a single batch, " type=int,
f"default={BATCH_SIZE_DEFAULT}") default=BATCH_SIZE_DEFAULT,
help=f"Number of requests to run as a single batch, "
f"default={BATCH_SIZE_DEFAULT}",
)
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
run_num_steps_parser = subparsers.add_parser( run_num_steps_parser = subparsers.add_parser(
"run_num_steps", "run_num_steps", help="This variation profiles n engine.step() invocations."
help="This variation profiles n engine.step() invocations.") )
run_num_steps_parser.add_argument( run_num_steps_parser.add_argument(
'-n', "-n",
'--num-steps', "--num-steps",
type=int, type=int,
help="Number of engine steps to profile.\n" help="Number of engine steps to profile.\n"
"Setting it to 1, profiles only the prefill step.\n" "Setting it to 1, profiles only the prefill step.\n"
"Setting it to 2, profiles the prefill and first decode step\n" "Setting it to 2, profiles the prefill and first decode step\n"
"Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
"and so on ...") "and so on ...",
)
run_to_completion_parser = subparsers.add_parser( run_to_completion_parser = subparsers.add_parser(
"run_to_completion", "run_to_completion",
help="This variation profiles all the engine.step() invocations" help="This variation profiles all the engine.step() invocations"
"until the engine exhausts all submitted requests.") "until the engine exhausts all submitted requests.",
)
run_to_completion_parser.add_argument( run_to_completion_parser.add_argument(
'-n', "-n",
'--complete-num-requests-per-step', "--complete-num-requests-per-step",
type=int, type=int,
help= help="Complete complete_num_requests_per_step requests every decode step."
"Complete complete_num_requests_per_step requests every decode step."
"For e.g., with batch_size 128 and complete_num_requests_per_step 32," "For e.g., with batch_size 128 and complete_num_requests_per_step 32,"
"the profiler is run for 6 engine steps, with the steps processing, " "the profiler is run for 6 engine steps, with the steps processing, "
"128, 128, 96, 64, 32, 1 requests respectively.\n" "128, 128, 96, 64, 32, 1 requests respectively.\n"
"Note that we tack-on a one-request step at the end as it is often " "Note that we tack-on a one-request step at the end as it is often "
"useful.") "useful.",
)
EngineArgs.add_cli_args(parser) EngineArgs.add_cli_args(parser)
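To make the `run_to_completion` help text above concrete, a short sketch (illustration only, not the body of `determine_requests_per_step`) of how that schedule comes about:

```python
# With batch_size=128 and complete_num_requests_per_step=32: all requests take
# part in the prefill step and the first decode step, then 32 requests finish
# per step, and a final one-request step is tacked on at the end.
batch_size = 128
complete_num_requests_per_step = 32

step_requests = [batch_size, batch_size]
while step_requests[-1] > complete_num_requests_per_step:
    step_requests.append(step_requests[-1] - complete_num_requests_per_step)
step_requests.append(1)
print(step_requests)  # [128, 128, 96, 64, 32, 1]
```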
@ -459,7 +499,8 @@ def main(args):
k: v k: v
for k, v in vars(args).items() for k, v in vars(args).items()
if k in inspect.signature(ProfileContext).parameters if k in inspect.signature(ProfileContext).parameters
}) },
)
run_profile(context, csv_output=args.csv, json_output=args.json) run_profile(context, csv_output=args.csv, json_output=args.json)
View File
@ -31,18 +31,16 @@ def main(args: argparse.Namespace):
max_tokens=args.output_len, max_tokens=args.output_len,
) )
print(sampling_params) print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(
size=(args.batch_size, 10000, size=(args.batch_size, args.input_len)
args.input_len)) )
dummy_prompts: list[PromptType] = [{ dummy_prompts: list[PromptType] = [
"prompt_token_ids": batch {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
} for batch in dummy_prompt_token_ids.tolist()] ]
def run_to_completion(): def run_to_completion():
start_time = time.perf_counter() start_time = time.perf_counter()
llm.generate(dummy_prompts, llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter() end_time = time.perf_counter()
latency = end_time - start_time latency = end_time - start_time
return latency return latency
@ -58,10 +56,9 @@ def main(args: argparse.Namespace):
profile_dir = args.profile_result_dir profile_dir = args.profile_result_dir
print(f"Profiling (results will be saved to '{profile_dir}')...") print(f"Profiling (results will be saved to '{profile_dir}')...")
# Enable tracing on server # Enable tracing on server
xp.trace_detached("localhost:9012", xp.trace_detached(
profile_dir, "localhost:9012", profile_dir, delay_ms=DELAY_MS, duration_ms=DURATION_MS
delay_ms=DELAY_MS, )
duration_ms=DURATION_MS)
if DELAY_MS == 0: if DELAY_MS == 0:
time.sleep(1.0) time.sleep(1.0)
profile_latencies = [] profile_latencies = []
@ -72,30 +69,36 @@ def main(args: argparse.Namespace):
return return
if __name__ == '__main__': if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of ' description="Benchmark the latency of processing a single batch of "
'requests till completion.') "requests till completion."
parser.add_argument('--input-len', type=int, default=32) )
parser.add_argument('--output-len', type=int, default=128) parser.add_argument("--input-len", type=int, default=32)
parser.add_argument('--batch-size', type=int, default=8) parser.add_argument("--output-len", type=int, default=128)
parser.add_argument('--num-iters-warmup', parser.add_argument("--batch-size", type=int, default=8)
type=int,
default=5,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
default=1,
help='Number of iterations to run for profiling.')
parser.add_argument( parser.add_argument(
'--profile-result-dir', "--num-iters-warmup",
type=int,
default=5,
help="Number of iterations to run for warmup.",
)
parser.add_argument(
"--num-iters",
type=int,
default=1,
help="Number of iterations to run for profiling.",
)
parser.add_argument(
"--profile-result-dir",
type=str, type=str,
default="profiles", default="profiles",
help= help=(
('path to save the pytorch profiler output. Can be visualized ' "path to save the pytorch profiler output. Can be visualized "
'with ui.perfetto.dev or Tensorboard ' "with ui.perfetto.dev or Tensorboard "
'(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).' "(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
)) ),
)
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
View File
@ -18,8 +18,7 @@ Run:
""" """
import torch import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
PreTrainedTokenizer)
from vllm import LLM from vllm import LLM
@ -32,27 +31,29 @@ def init_tokenizer_and_llm(model_name: str):
return tokenizer, embedding_layer, llm return tokenizer, embedding_layer, llm
def get_prompt_embeds(chat: list[dict[str, def get_prompt_embeds(
str]], tokenizer: PreTrainedTokenizer, chat: list[dict[str, str]],
embedding_layer: torch.nn.Module): tokenizer: PreTrainedTokenizer,
token_ids = tokenizer.apply_chat_template(chat, embedding_layer: torch.nn.Module,
add_generation_prompt=True, ):
return_tensors='pt') token_ids = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, return_tensors="pt"
)
prompt_embeds = embedding_layer(token_ids).squeeze(0) prompt_embeds = embedding_layer(token_ids).squeeze(0)
return prompt_embeds return prompt_embeds
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, def single_prompt_inference(
embedding_layer: torch.nn.Module): llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
chat = [{ ):
"role": "user", chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
"content": "Please tell me about the capital of France."
}]
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer) prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
outputs = llm.generate({ outputs = llm.generate(
"prompt_embeds": prompt_embeds, {
}) "prompt_embeds": prompt_embeds,
}
)
print("\n[Single Inference Output]") print("\n[Single Inference Output]")
print("-" * 30) print("-" * 30)
@ -61,34 +62,26 @@ def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
print("-" * 30) print("-" * 30)
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, def batch_prompt_inference(
embedding_layer: torch.nn.Module): llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
chats = [[{ ):
"role": "user", chats = [
"content": "Please tell me about the capital of France." [{"role": "user", "content": "Please tell me about the capital of France."}],
}], [{"role": "user", "content": "When is the day longest during the year?"}],
[{ [{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
"role": "user", ]
"content": "When is the day longest during the year?"
}],
[{
"role": "user",
"content": "Where is bigger, the moon or the sun?"
}]]
prompt_embeds_list = [ prompt_embeds_list = [
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
] ]
outputs = llm.generate([{ outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
"prompt_embeds": embeds
} for embeds in prompt_embeds_list])
print("\n[Batch Inference Outputs]") print("\n[Batch Inference Outputs]")
print("-" * 30) print("-" * 30)
for i, o in enumerate(outputs): for i, o in enumerate(outputs):
print(f"Q{i+1}: {chats[i][0]['content']}") print(f"Q{i + 1}: {chats[i][0]['content']}")
print(f"A{i+1}: {o.outputs[0].text}\n") print(f"A{i + 1}: {o.outputs[0].text}\n")
print("-" * 30) print("-" * 30)

View File
@ -27,51 +27,55 @@ class QueryResult(NamedTuple):
default_system = ( default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as " "Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.") "generating text and speech."
)
def get_mixed_modalities_query() -> QueryResult: def get_mixed_modalities_query() -> QueryResult:
question = ("What is recited in the audio? " question = (
"What is the content of this image? Why is this video funny?") "What is recited in the audio? "
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" "What is the content of this image? Why is this video funny?"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" )
"<|vision_bos|><|IMAGE|><|vision_eos|>" prompt = (
"<|vision_bos|><|VIDEO|><|vision_eos|>" f"<|im_start|>system\n{default_system}<|im_end|>\n"
f"{question}<|im_end|>\n" "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
f"<|im_start|>assistant\n") "<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult( return QueryResult(
inputs={ inputs={
"prompt": prompt, "prompt": prompt,
"multi_modal_data": { "multi_modal_data": {
"audio": "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate, "image": convert_image_mode(
"image": ImageAsset("cherry_blossom").pil_image, "RGB"
convert_image_mode( ),
ImageAsset("cherry_blossom").pil_image, "RGB"), "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
"video":
VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
}, },
}, },
limit_mm_per_prompt={ limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
"audio": 1,
"image": 1,
"video": 1
},
) )
def get_use_audio_in_video_query() -> QueryResult: def get_use_audio_in_video_query() -> QueryResult:
question = ("Describe the content of the video, " question = (
"then convert what the baby say into text.") "Describe the content of the video, then convert what the baby say into text."
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" )
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" prompt = (
f"{question}<|im_end|>\n" f"<|im_start|>system\n{default_system}<|im_end|>\n"
f"<|im_start|>assistant\n") "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
asset = VideoAsset(name="baby_reading", num_frames=16) asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000) audio = asset.get_audio(sampling_rate=16000)
assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. " assert not envs.VLLM_USE_V1, (
"Please launch this example with " "V1 does not support use_audio_in_video. "
"`VLLM_USE_V1=0`.") "Please launch this example with "
"`VLLM_USE_V1=0`."
)
return QueryResult( return QueryResult(
inputs={ inputs={
"prompt": prompt, "prompt": prompt,
@ -83,20 +87,19 @@ def get_use_audio_in_video_query() -> QueryResult:
"use_audio_in_video": True, "use_audio_in_video": True,
}, },
}, },
limit_mm_per_prompt={ limit_mm_per_prompt={"audio": 1, "video": 1},
"audio": 1,
"video": 1
},
) )
def get_multi_audios_query() -> QueryResult: def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?" question = "Are these two audio clips the same?"
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" prompt = (
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|audio_bos|><|AUDIO|><|audio_eos|>" "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n" "<|audio_bos|><|AUDIO|><|audio_eos|>"
f"<|im_start|>assistant\n") f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult( return QueryResult(
inputs={ inputs={
"prompt": prompt, "prompt": prompt,
@ -124,18 +127,19 @@ def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B" model_name = "Qwen/Qwen2.5-Omni-7B"
query_result = query_map[args.query_type]() query_result = query_map[args.query_type]()
llm = LLM(model=model_name, llm = LLM(
max_model_len=5632, model=model_name,
max_num_seqs=5, max_model_len=5632,
limit_mm_per_prompt=query_result.limit_mm_per_prompt, max_num_seqs=5,
seed=args.seed) limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed,
)
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference. # even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64) sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
outputs = llm.generate(query_result.inputs, outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
sampling_params=sampling_params)
for o in outputs: for o in outputs:
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
@ -144,18 +148,23 @@ def main(args):
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'audio language models') "audio language models"
parser.add_argument('--query-type', )
'-q', parser.add_argument(
type=str, "--query-type",
default="mixed_modalities", "-q",
choices=query_map.keys(), type=str,
help='Query type.') default="mixed_modalities",
parser.add_argument("--seed", choices=query_map.keys(),
type=int, help="Query type.",
default=None, )
help="Set the seed when initializing `vllm.LLM`.") parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args() return parser.parse_args()
View File
@ -17,10 +17,10 @@ def load_prompt() -> str:
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with urlopen( with urlopen(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com" "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
"/Qwen2.5-1M/test-data/600k.txt", timeout=5,
timeout=5) as response: ) as response:
prompt = response.read().decode('utf-8') prompt = response.read().decode("utf-8")
return prompt return prompt
@ -41,18 +41,22 @@ def process_requests(llm: LLM, prompts: list[str]) -> None:
for output in outputs: for output in outputs:
prompt_token_ids = output.prompt_token_ids prompt_token_ids = output.prompt_token_ids
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt length: {len(prompt_token_ids)}, " print(
f"Generated text: {generated_text!r}") f"Prompt length: {len(prompt_token_ids)}, "
f"Generated text: {generated_text!r}"
)
# Create an LLM. # Create an LLM.
def initialize_engine() -> LLM: def initialize_engine() -> LLM:
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M", llm = LLM(
max_model_len=1048576, model="Qwen/Qwen2.5-7B-Instruct-1M",
tensor_parallel_size=4, max_model_len=1048576,
enforce_eager=True, tensor_parallel_size=4,
enable_chunked_prefill=True, enforce_eager=True,
max_num_batched_tokens=131072) enable_chunked_prefill=True,
max_num_batched_tokens=131072,
)
return llm return llm
@ -62,5 +66,5 @@ def main():
process_requests(llm, [prompt]) process_requests(llm, [prompt])
if __name__ == '__main__': if __name__ == "__main__":
main() main()
View File
@ -12,6 +12,7 @@ inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework. to the OpenRLHF framework.
""" """
import os import os
import ray import ray
@ -26,7 +27,6 @@ from vllm.utils import get_ip, get_open_port
class MyLLM(LLM): class MyLLM(LLM):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# a hack to make the script work. # a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES # stop ray from manipulating CUDA_VISIBLE_DEVICES
@ -89,8 +89,7 @@ print("-" * 50)
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n" print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
f"Generated text: {generated_text!r}")
print("-" * 50) print("-" * 50)
# set up the communication between the training process # set up the communication between the training process
@ -98,11 +97,13 @@ for output in outputs:
master_address = get_ip() master_address = get_ip()
master_port = get_open_port() master_port = get_open_port()
handle = llm.collective_rpc.remote("init_weight_update_group", handle = llm.collective_rpc.remote(
args=(master_address, master_port, 1, 3)) "init_weight_update_group", args=(master_address, master_port, 1, 3)
)
model_update_group = stateless_init_process_group(master_address, master_port, model_update_group = stateless_init_process_group(
0, 3, torch.device("cuda:0")) master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle) ray.get(handle)
# simulate training, modify the weights of the model. # simulate training, modify the weights of the model.
@ -111,8 +112,7 @@ for name, p in train_model.named_parameters():
# sync weight from the training process to the inference engine. # sync weight from the training process to the inference engine.
for name, p in train_model.named_parameters(): for name, p in train_model.named_parameters():
handle = llm.collective_rpc.remote("update_weight", handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle) ray.get(handle)
@ -126,6 +126,5 @@ print("-" * 50)
for output in outputs_updated: for output in outputs_updated:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n" print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
f"Generated text: {generated_text!r}")
print("-" * 50) print("-" * 50)
View File
@ -9,6 +9,7 @@ The key points:
- Use cuda-ipc to pass tensors, since NCCL does not work when we have - Use cuda-ipc to pass tensors, since NCCL does not work when we have
multiple processes on the same GPU. multiple processes on the same GPU.
""" """
import os import os
import ray import ray
@ -20,7 +21,6 @@ from vllm import LLM
class MyLLM(LLM): class MyLLM(LLM):
def __init__(self, *args, bundle_indices: list, **kwargs): def __init__(self, *args, bundle_indices: list, **kwargs):
# a hack to make the script work. # a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES # stop ray from manipulating CUDA_VISIBLE_DEVICES
@ -29,17 +29,16 @@ class MyLLM(LLM):
# every worker will use 0.4 GPU, so that we can schedule # every worker will use 0.4 GPU, so that we can schedule
# 2 instances on the same GPUs. # 2 instances on the same GPUs.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join( os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
map(str, bundle_indices))
print(f"creating LLM with bundle_indices={bundle_indices}") print(f"creating LLM with bundle_indices={bundle_indices}")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class RayTrainingActor: class RayTrainingActor:
def __init__(self): def __init__(self):
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model.to("cuda:0") self.model.to("cuda:0")
for name, p in self.model.named_parameters(): for name, p in self.model.named_parameters():
@ -48,6 +47,7 @@ class RayTrainingActor:
# the argument for get_device_uuid is the index # the argument for get_device_uuid is the index
# of the GPU in the visible devices. # of the GPU in the visible devices.
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(0) self.device_uuid = current_platform.get_device_uuid(0)
def report_device_id(self) -> str: def report_device_id(self) -> str:
@ -55,6 +55,7 @@ class RayTrainingActor:
def get_weight_ipc_handles(self): def get_weight_ipc_handles(self):
from torch.multiprocessing.reductions import reduce_tensor from torch.multiprocessing.reductions import reduce_tensor
data = {} data = {}
for name, p in self.model.named_parameters(): for name, p in self.model.named_parameters():
# the training actor might only have a subset of the weights # the training actor might only have a subset of the weights
@ -101,7 +102,7 @@ for bundle_index, training_actor in enumerate(training_actors):
print(f"training actor {bundle_index} is on {device_id}") print(f"training actor {bundle_index} is on {device_id}")
training_actor_device_ids.append(device_id) training_actor_device_ids.append(device_id)
for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]): for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
# IMPORTANT: when creating vLLM instances, we need to # IMPORTANT: when creating vLLM instances, we need to
# make sure there are no GPU activities on the target GPUs, # make sure there are no GPU activities on the target GPUs,
# otherwise, they will interfere with the vLLM memory profiling, # otherwise, they will interfere with the vLLM memory profiling,
@ -128,7 +129,8 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
for i, llm in enumerate(inference_engines): for i, llm in enumerate(inference_engines):
inference_engine_device_ids.append( inference_engine_device_ids.append(
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))) ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
)
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
# check the placement # check the placement
@ -147,9 +149,10 @@ for actor in training_actors:
print("update the weights of the inference engines") print("update the weights of the inference engines")
for llm in inference_engines: for llm in inference_engines:
ray.get( ray.get(
llm.collective_rpc.remote("update_weights_from_ipc_handles", llm.collective_rpc.remote(
args=(ipc_handles, ))) "update_weights_from_ipc_handles", args=(ipc_handles,)
)
)
print("check if the weights are updated") print("check if the weights are updated")
for llm in inference_engines: for llm in inference_engines:
assert ray.get( assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
llm.collective_rpc.remote("check_weights_changed", args=tuple()))
View File
@ -2,8 +2,7 @@
import torch import torch
def stateless_init_process_group(master_address, master_port, rank, world_size, def stateless_init_process_group(master_address, master_port, rank, world_size, device):
device):
""" """
vLLM provides `StatelessProcessGroup` to create a process group vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed. without considering the global process group in torch.distributed.
@ -13,10 +12,10 @@ def stateless_init_process_group(master_address, master_port, rank, world_size,
""" """
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port, pg = StatelessProcessGroup.create(
rank=rank, host=master_address, port=master_port, rank=rank, world_size=world_size
world_size=world_size) )
pynccl = PyNcclCommunicator(pg, device=device) pynccl = PyNcclCommunicator(pg, device=device)
return pynccl return pynccl
@ -31,9 +30,11 @@ class WorkerExtension:
should pass the fully qualified name as the `worker_extension_cls` argument. should pass the fully qualified name as the `worker_extension_cls` argument.
""" """
def init_weight_update_group(self, master_address, master_port, def init_weight_update_group(
rank_offset, world_size): self, master_address, master_port, rank_offset, world_size
):
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group( self.model_update_group = stateless_init_process_group(
master_address, master_address,
@ -45,9 +46,9 @@ class WorkerExtension:
def update_weight(self, name, dtype, shape): def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda") weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight, self.model_update_group.broadcast(
src=0, weight, src=0, stream=torch.cuda.current_stream()
stream=torch.cuda.current_stream()) )
self.model_runner.model.load_weights(weights=[(name, weight)]) self.model_runner.model.load_weights(weights=[(name, weight)])
@ -59,8 +60,7 @@ class WorkerExtension:
""" """
weights_updated = True weights_updated = True
for name, p in self.model_runner.model.named_parameters(): for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose( weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
p, torch.zeros_like(p))
return weights_updated return weights_updated
@ -76,6 +76,7 @@ class ColocateWorkerExtension:
def report_device_id(self) -> str: def report_device_id(self) -> str:
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index) self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid return self.device_uuid
@ -100,6 +101,5 @@ class ColocateWorkerExtension:
""" """
weights_updated = True weights_updated = True
for name, p in self.model_runner.model.named_parameters(): for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose( weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
p, torch.zeros_like(p))
return weights_updated return weights_updated
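
As the docstring above notes, a worker extension is attached by its fully qualified class name. A minimal sketch of wiring it up, assuming this module is importable as rlhf_utils and using a small placeholder model:

from vllm import LLM

# Placeholder model and module path for illustration only.
llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    worker_extension_cls="rlhf_utils.WorkerExtension",
)

# Methods defined on the extension are then reachable through collective_rpc,
# e.g. llm.collective_rpc("check_weights_changed", args=tuple())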
View File
@ -21,6 +21,7 @@ llm = LLM(
tensor_parallel_size=8, tensor_parallel_size=8,
) )
""" """
import dataclasses import dataclasses
import os import os
import shutil import shutil
@ -33,18 +34,18 @@ from vllm.utils import FlexibleArgumentParser
def parse_args(): def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser) EngineArgs.add_cli_args(parser)
parser.add_argument("--output", parser.add_argument(
"-o", "--output", "-o", required=True, type=str, help="path to output checkpoint"
required=True, )
type=str, parser.add_argument(
help="path to output checkpoint") "--file-pattern", type=str, help="string pattern of saved filenames"
parser.add_argument("--file-pattern", )
type=str, parser.add_argument(
help="string pattern of saved filenames") "--max-file-size",
parser.add_argument("--max-file-size", type=str,
type=str, default=5 * 1024**3,
default=5 * 1024**3, help="max size (in bytes) of each safetensors file",
help="max size (in bytes) of each safetensors file") )
return parser.parse_args() return parser.parse_args()
@ -68,23 +69,23 @@ def main(args):
# For V1 engine, we need to use engine_core.save_sharded_state # For V1 engine, we need to use engine_core.save_sharded_state
print("Using V1 engine save path") print("Using V1 engine save path")
llm.llm_engine.engine_core.save_sharded_state( llm.llm_engine.engine_core.save_sharded_state(
path=args.output, path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
pattern=args.file_pattern, )
max_size=args.max_file_size)
else: else:
# For V0 engine # For V0 engine
print("Using V0 engine save path") print("Using V0 engine save path")
model_executor = llm.llm_engine.model_executor model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=args.output, model_executor.save_sharded_state(
pattern=args.file_pattern, path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
max_size=args.max_file_size) )
# Copy metadata files to output directory # Copy metadata files to output directory
for file in os.listdir(model_path): for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)): if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(os.path.join(model_path, file), shutil.copytree(
os.path.join(args.output, file)) os.path.join(model_path, file), os.path.join(args.output, file)
)
else: else:
shutil.copy(os.path.join(model_path, file), args.output) shutil.copy(os.path.join(model_path, file), args.output)
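
A sharded checkpoint written this way is intended to be reloaded with load_format="sharded_state"; a brief sketch, with placeholder path and parallel size:

from vllm import LLM

# Placeholder path; tensor_parallel_size should match the value used when
# the sharded state was saved.
llm = LLM(
    model="/path/to/output/checkpoint",
    load_format="sharded_state",
    tensor_parallel_size=8,
)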
View File
@ -15,20 +15,20 @@ from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
# Guided decoding by Choice (list of possible options) # Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams( guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
choice=["Positive", "Negative"]) sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
sampling_params_choice = SamplingParams(
guided_decoding=guided_decoding_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!" prompt_choice = "Classify this sentiment: vLLM is wonderful!"
# Guided decoding by Regex # Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams( sampling_params_regex = SamplingParams(
guided_decoding=guided_decoding_params_regex, stop=["\n"]) guided_decoding=guided_decoding_params_regex, stop=["\n"]
)
prompt_regex = ( prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma." "Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:" "End in .com and new line. Example result:"
"alan.turing@enigma.com\n") "alan.turing@enigma.com\n"
)
# Guided decoding by JSON using Pydantic schema # Guided decoding by JSON using Pydantic schema
@ -47,10 +47,11 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema) guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams( sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
guided_decoding=guided_decoding_params_json) prompt_json = (
prompt_json = ("Generate a JSON with the brand, model and car_type of" "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's") "the most iconic car from the 90's"
)
# Guided decoding by Grammar # Guided decoding by Grammar
simplified_sql_grammar = """ simplified_sql_grammar = """
@ -61,12 +62,11 @@ table ::= "table_1 " | "table_2 "
condition ::= column "= " number condition ::= column "= " number
number ::= "1 " | "2 " number ::= "1 " | "2 "
""" """
guided_decoding_params_grammar = GuidedDecodingParams( guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
grammar=simplified_sql_grammar) sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
sampling_params_grammar = SamplingParams( prompt_grammar = (
guided_decoding=guided_decoding_params_grammar) "Generate an SQL query to show the 'username' and 'email' from the 'users' table."
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'" )
"from the 'users' table.")
def format_output(title: str, output: str): def format_output(title: str, output: str):
@ -90,8 +90,7 @@ def main():
json_output = generate_output(prompt_json, sampling_params_json, llm) json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Guided decoding by JSON", json_output) format_output("Guided decoding by JSON", json_output)
grammar_output = generate_output(prompt_grammar, sampling_params_grammar, grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
llm)
format_output("Guided decoding by Grammar", grammar_output) format_output("Guided decoding by Grammar", grammar_output)
View File
@ -45,8 +45,7 @@ if dist.get_rank() == 0:
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n" print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
f"Generated text: {generated_text!r}\n")
print("-" * 50) print("-" * 50)
""" """
Further tips: Further tips:
View File
@ -20,10 +20,12 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
def main(): def main():
# Set `enforce_eager=True` to avoid ahead-of-time compilation. # Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`. # In real workloads, `enforce_eager` should be `False`.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", llm = LLM(
max_num_batched_tokens=64, model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=4, max_num_batched_tokens=64,
max_model_len=128) max_num_seqs=4,
max_model_len=128,
)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
print("-" * 50) print("-" * 50)
for output, answer in zip(outputs, answers): for output, answer in zip(outputs, answers):
View File
@ -6,6 +6,7 @@ the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow corresponding examples For most models, the prompt format should follow corresponding examples
on HuggingFace model repository. on HuggingFace model repository.
""" """
import os import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
@ -49,9 +50,13 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}" prompts = [
"<|im_end|>\n<|im_start|>assistant\n") (
for question in questions] f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
"<|im_end|>\n<|im_start|>assistant\n"
)
for question in questions
]
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
@ -135,8 +140,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
) )
prompts = [ prompts = [
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" for question in questions
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
@ -198,9 +202,14 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
prompts = [("<bos><start_of_turn>user\n" prompts = [
f"<start_of_image>{question}<end_of_turn>\n" (
"<start_of_turn>model\n") for question in questions] "<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n"
)
for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -225,7 +234,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
prompts = [ prompts = [
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
{question}<|assistant|>" for question in questions {question}<|assistant|>"
for question in questions
] ]
stop_token_ids = [151329, 151336, 151338] stop_token_ids = [151329, 151336, 151338]
@ -250,15 +260,13 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
# Stop tokens for H2OVL-Mississippi # Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m # https://huggingface.co/h2oai/h2ovl-mississippi-800m
@ -284,15 +292,14 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
# if you are running out of memory, you can reduce the "longest_edge". # if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={ mm_processor_kwargs={
"size": { "size": {"longest_edge": 3 * 364},
"longest_edge": 3 * 364
},
}, },
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
prompts = [( prompts = [
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:" (f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:")
) for question in questions] for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -311,9 +318,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2, max_num_seqs=2,
enforce_eager=True, enforce_eager=True,
mm_processor_kwargs={ mm_processor_kwargs={
"max_image_size": { "max_image_size": {"longest_edge": 384},
"longest_edge": 384
},
}, },
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
@ -330,7 +335,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
# InternVL # InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData: def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "OpenGVLab/InternVL3-2B" model_name = "OpenGVLab/InternVL3-2B"
engine_args = EngineArgs( engine_args = EngineArgs(
@ -345,15 +349,14 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
elif modality == "video": elif modality == "video":
placeholder = "<video>" placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"{placeholder}\n{question}"}]
'role': 'user', for question in questions
'content': f"{placeholder}\n{question}" ]
}] for question in questions] prompts = tokenizer.apply_chat_template(
prompts = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
# Stop tokens for InternVL # Stop tokens for InternVL
# models variants may have different stop tokens # models variants may have different stop tokens
@ -361,9 +364,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
stop_token_ids = [ stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
token_id for token_id in stop_token_ids if token_id is not None
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -379,7 +380,8 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
prompts = [ prompts = [
"<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>" "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
f"<|media_pad|><|media_end|>{question}<|im_end|>" f"<|media_pad|><|media_end|>{question}<|im_end|>"
"<|im_assistant|>assistant<|im_middle|>" for question in questions "<|im_assistant|>assistant<|im_middle|>"
for question in questions
] ]
engine_args = EngineArgs( engine_args = EngineArgs(
@ -399,9 +401,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
def run_llava(questions: list[str], modality: str) -> ModelRequestData: def run_llava(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
prompts = [ prompts = [f"USER: <image>\n{question}\nASSISTANT:" for question in questions]
f"USER: <image>\n{question}\nASSISTANT:" for question in questions
]
engine_args = EngineArgs( engine_args = EngineArgs(
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
@ -434,13 +434,10 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
# LlaVA-NeXT-Video # LlaVA-NeXT-Video
# Currently only support for video input # Currently only support for video input
def run_llava_next_video(questions: list[str], def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestData:
modality: str) -> ModelRequestData:
assert modality == "video" assert modality == "video"
prompts = [ prompts = [f"USER: <video>\n{question} ASSISTANT:" for question in questions]
f"USER: <video>\n{question} ASSISTANT:" for question in questions
]
engine_args = EngineArgs( engine_args = EngineArgs(
model="llava-hf/LLaVA-NeXT-Video-7B-hf", model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192, max_model_len=8192,
@ -455,19 +452,19 @@ def run_llava_next_video(questions: list[str],
# LLaVA-OneVision # LLaVA-OneVision
def run_llava_onevision(questions: list[str], def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData:
modality: str) -> ModelRequestData:
if modality == "video": if modality == "video":
prompts = [ prompts = [
f"<|im_start|>user <video>\n{question}<|im_end|> \ f"<|im_start|>user <video>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions <|im_start|>assistant\n"
for question in questions
] ]
elif modality == "image": elif modality == "image":
prompts = [ prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \ f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions <|im_start|>assistant\n"
for question in questions
] ]
engine_args = EngineArgs( engine_args = EngineArgs(
@ -486,11 +483,8 @@ def run_llava_onevision(questions: list[str],
def run_mantis(questions: list[str], modality: str) -> ModelRequestData: def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501 llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" # noqa: E501
prompts = [ prompts = [llama3_template.format(f"{question}\n<image>") for question in questions]
llama3_template.format(f"{question}\n<image>")
for question in questions
]
engine_args = EngineArgs( engine_args = EngineArgs(
model="TIGER-Lab/Mantis-8B-siglip-llama3", model="TIGER-Lab/Mantis-8B-siglip-llama3",
@ -530,8 +524,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
# 2.6: image, video # 2.6: image, video
# o2.6: image, video, audio # o2.6: image, video, audio
# model_name = "openbmb/MiniCPM-o-2_6" # model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True)
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_model_len=4096, max_model_len=4096,
@ -547,7 +540,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
# stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id] # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
# 2.6 / o2.6 # 2.6 / o2.6
stop_tokens = ['<|im_end|>', '<|endoftext|>'] stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
modality_placeholder = { modality_placeholder = {
@ -557,12 +550,16 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
prompts = [ prompts = [
tokenizer.apply_chat_template( tokenizer.apply_chat_template(
[{ [
'role': 'user', {
'content': f"{modality_placeholder[modality]}\n{question}" "role": "user",
}], "content": f"{modality_placeholder[modality]}\n{question}",
}
],
tokenize=False, tokenize=False,
add_generation_prompt=True) for question in questions add_generation_prompt=True,
)
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
@ -622,19 +619,18 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
) )
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{ messages = [
"role": [
"user", {
"content": [{ "role": "user",
"type": "image" "content": [{"type": "image"}, {"type": "text", "text": question}],
}, { }
"type": "text", ]
"text": question for question in questions
}] ]
}] for question in questions] prompts = tokenizer.apply_chat_template(
prompts = tokenizer.apply_chat_template(messages, messages, add_generation_prompt=True, tokenize=False
add_generation_prompt=True, )
tokenize=False)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -657,19 +653,18 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
) )
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{ messages = [
"role": [
"user", {
"content": [{ "role": "user",
"type": "image" "content": [{"type": "image"}, {"type": "text", "text": f"{question}"}],
}, { }
"type": "text", ]
"text": f"{question}" for question in questions
}] ]
}] for question in questions] prompts = tokenizer.apply_chat_template(
prompts = tokenizer.apply_chat_template(messages, messages, add_generation_prompt=True, tokenize=False
add_generation_prompt=True, )
tokenize=False)
stop_token_ids = None stop_token_ids = None
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -693,7 +688,8 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
prompts = [ prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \ f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions <|im_start|>assistant\n"
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
@ -717,15 +713,13 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -748,15 +742,13 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -847,8 +839,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
# we have to manually specify the path of the lora weights. # we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora") vision_lora_path = os.path.join(model_path, "vision-lora")
prompts = [ prompts = [
f"<|user|><|image_1|>{question}<|end|><|assistant|>" f"<|user|><|image_1|>{question}<|end|><|assistant|>" for question in questions
for question in questions
] ]
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_path, model=model_path,
@ -915,7 +906,6 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
# Qwen2-VL # Qwen2-VL
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2-VL-7B-Instruct" model_name = "Qwen/Qwen2-VL-7B-Instruct"
engine_args = EngineArgs( engine_args = EngineArgs(
@ -936,10 +926,13 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
placeholder = "<|video_pad|>" placeholder = "<|video_pad|>"
prompts = [ prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" (
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"{question}<|im_end|>\n" f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
"<|im_start|>assistant\n") for question in questions f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
@ -950,7 +943,6 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
# Qwen2.5-VL # Qwen2.5-VL
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2.5-VL-3B-Instruct" model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
engine_args = EngineArgs( engine_args = EngineArgs(
@ -971,10 +963,13 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
placeholder = "<|video_pad|>" placeholder = "<|video_pad|>"
prompts = [ prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" (
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"{question}<|im_end|>\n" f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
"<|im_start|>assistant\n") for question in questions f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
@ -1007,12 +1002,18 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
default_system = ( default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as " "Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.") "generating text and speech."
)
prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n" prompts = [
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>" (
f"{question}<|im_end|>\n" f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions] f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompts=prompts, prompts=prompts,
@ -1032,15 +1033,13 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
# Stop tokens for SkyworkR1V # Stop tokens for SkyworkR1V
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
@ -1104,8 +1103,7 @@ def get_multi_modal_input(args):
""" """
if args.modality == "image": if args.modality == "image":
# Input image and question # Input image and question
image = convert_image_mode( image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
ImageAsset("cherry_blossom").pil_image, "RGB")
img_questions = [ img_questions = [
"What is the content of this image?", "What is the content of this image?",
"Describe the content of this image in detail.", "Describe the content of this image in detail.",
@ -1120,8 +1118,7 @@ def get_multi_modal_input(args):
if args.modality == "video": if args.modality == "video":
# Input video and question # Input video and question
video = VideoAsset(name="baby_reading", video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
num_frames=args.num_frames).np_ndarrays
vid_questions = ["Why is this video funny?"] vid_questions = ["Why is this video funny?"]
return { return {
@ -1133,12 +1130,13 @@ def get_multi_modal_input(args):
raise ValueError(msg) raise ValueError(msg)
def apply_image_repeat(image_repeat_prob, num_prompts, data, def apply_image_repeat(
prompts: list[str], modality): image_repeat_prob, num_prompts, data, prompts: list[str], modality
):
"""Repeats images with provided probability of "image_repeat_prob". """Repeats images with provided probability of "image_repeat_prob".
Used to simulate hit/miss for the MM preprocessor cache. Used to simulate hit/miss for the MM preprocessor cache.
""" """
assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0) assert image_repeat_prob <= 1.0 and image_repeat_prob >= 0
no_yes = [0, 1] no_yes = [0, 1]
probs = [1.0 - image_repeat_prob, image_repeat_prob] probs = [1.0 - image_repeat_prob, image_repeat_prob]
@ -1153,12 +1151,12 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
new_val = (i // 256 // 256, i // 256, i % 256) new_val = (i // 256 // 256, i // 256, i % 256)
cur_image.putpixel((0, 0), new_val) cur_image.putpixel((0, 0), new_val)
inputs.append({ inputs.append(
"prompt": prompts[i % len(prompts)], {
"multi_modal_data": { "prompt": prompts[i % len(prompts)],
modality: cur_image "multi_modal_data": {modality: cur_image},
} }
}) )
return inputs return inputs
@ -1167,6 +1165,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
def time_counter(enable: bool): def time_counter(enable: bool):
if enable: if enable:
import time import time
start_time = time.time() start_time = time.time()
yield yield
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@ -1179,54 +1178,65 @@ def time_counter(enable: bool):
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'vision language models for text generation') "vision language models for text generation"
parser.add_argument('--model-type', )
'-m', parser.add_argument(
type=str, "--model-type",
default="llava", "-m",
choices=model_example_map.keys(), type=str,
help='Huggingface "model_type".') default="llava",
parser.add_argument('--num-prompts', choices=model_example_map.keys(),
type=int, help='Huggingface "model_type".',
default=4, )
help='Number of prompts to run.') parser.add_argument(
parser.add_argument('--modality', "--num-prompts", type=int, default=4, help="Number of prompts to run."
type=str, )
default="image", parser.add_argument(
choices=['image', 'video'], "--modality",
help='Modality of the input.') type=str,
parser.add_argument('--num-frames', default="image",
type=int, choices=["image", "video"],
default=16, help="Modality of the input.",
help='Number of frames to extract from the video.') )
parser.add_argument("--seed", parser.add_argument(
type=int, "--num-frames",
default=None, type=int,
help="Set the seed when initializing `vllm.LLM`.") default=16,
help="Number of frames to extract from the video.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument( parser.add_argument(
'--image-repeat-prob', "--image-repeat-prob",
type=float, type=float,
default=None, default=None,
help='Simulates the hit-ratio for multi-modal preprocessor cache' help="Simulates the hit-ratio for multi-modal preprocessor cache (if enabled)",
' (if enabled)') )
parser.add_argument( parser.add_argument(
'--disable-mm-preprocessor-cache', "--disable-mm-preprocessor-cache",
action='store_true', action="store_true",
help='If True, disables caching of multi-modal preprocessor/mapper.') help="If True, disables caching of multi-modal preprocessor/mapper.",
)
parser.add_argument( parser.add_argument(
'--time-generate', "--time-generate",
action='store_true', action="store_true",
help='If True, then print the total generate() call time') help="If True, then print the total generate() call time",
)
parser.add_argument( parser.add_argument(
'--use-different-prompt-per-request', "--use-different-prompt-per-request",
action='store_true', action="store_true",
help='If True, then use different prompt (with the same multi-modal ' help="If True, then use different prompt (with the same multi-modal "
'data) for each request.') "data) for each request.",
)
return parser.parse_args() return parser.parse_args()
@ -1245,7 +1255,8 @@ def main(args):
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | { engine_args = asdict(req_data.engine_args) | {
"seed": args.seed, "seed": args.seed,
@ -1254,44 +1265,46 @@ def main(args):
llm = LLM(**engine_args) llm = LLM(**engine_args)
# Don't want to check the flag multiple times, so just hijack `prompts`. # Don't want to check the flag multiple times, so just hijack `prompts`.
prompts = req_data.prompts if args.use_different_prompt_per_request else [ prompts = (
req_data.prompts[0] req_data.prompts
] if args.use_different_prompt_per_request
else [req_data.prompts[0]]
)
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference. # even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, sampling_params = SamplingParams(
max_tokens=64, temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
stop_token_ids=req_data.stop_token_ids) )
assert args.num_prompts > 0 assert args.num_prompts > 0
if args.num_prompts == 1: if args.num_prompts == 1:
# Single inference # Single inference
inputs = { inputs = {
"prompt": prompts[0], "prompt": prompts[0],
"multi_modal_data": { "multi_modal_data": {modality: data},
modality: data
},
} }
else: else:
# Batch inference # Batch inference
if args.image_repeat_prob is not None: if args.image_repeat_prob is not None:
# Repeat images with specified probability of "image_repeat_prob" # Repeat images with specified probability of "image_repeat_prob"
inputs = apply_image_repeat(args.image_repeat_prob, inputs = apply_image_repeat(
args.num_prompts, data, prompts, args.image_repeat_prob, args.num_prompts, data, prompts, modality
modality) )
else: else:
# Use the same image for all prompts # Use the same image for all prompts
inputs = [{ inputs = [
"prompt": prompts[i % len(prompts)], {
"multi_modal_data": { "prompt": prompts[i % len(prompts)],
modality: data "multi_modal_data": {modality: data},
}, }
} for i in range(args.num_prompts)] for i in range(args.num_prompts)
]
# Add LoRA request if applicable # Add LoRA request if applicable
lora_request = (req_data.lora_requests * lora_request = (
args.num_prompts if req_data.lora_requests else None) req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
)
with time_counter(args.time_generate): with time_counter(args.time_generate):
outputs = llm.generate( outputs = llm.generate(
View File
@ -6,6 +6,7 @@ the correct prompt format on vision language models for multimodal embedding.
For most models, the prompt format should follow corresponding examples For most models, the prompt format should follow corresponding examples
on HuggingFace model repository. on HuggingFace model repository.
""" """
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict from dataclasses import asdict
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
@ -44,19 +45,17 @@ class ModelRequestData(NamedTuple):
def run_e5_v(query: Query) -> ModelRequestData: def run_e5_v(query: Query) -> ModelRequestData:
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501
if query["modality"] == "text": if query["modality"] == "text":
text = query["text"] text = query["text"]
prompt = llama3_template.format( prompt = llama3_template.format(f"{text}\nSummary above sentence in one word: ")
f"{text}\nSummary above sentence in one word: ")
image = None image = None
elif query["modality"] == "image": elif query["modality"] == "image":
prompt = llama3_template.format( prompt = llama3_template.format("<image>\nSummary above image in one word: ")
"<image>\nSummary above image in one word: ")
image = query["image"] image = query["image"]
else: else:
modality = query['modality'] modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'") raise ValueError(f"Unsupported query modality: '{modality}'")
engine_args = EngineArgs( engine_args = EngineArgs(
@ -83,10 +82,12 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
image = query["image"] image = query["image"]
elif query["modality"] == "text+image": elif query["modality"] == "text+image":
text = query["text"] text = query["text"]
prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501 prompt = (
f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
)
image = query["image"] image = query["image"]
else: else:
modality = query['modality'] modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'") raise ValueError(f"Unsupported query modality: '{modality}'")
engine_args = EngineArgs( engine_args = EngineArgs(
@ -136,7 +137,8 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": seed} engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
@ -145,10 +147,12 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
if req_data.image is not None: if req_data.image is not None:
mm_data["image"] = req_data.image mm_data["image"] = req_data.image
outputs = llm.embed({ outputs = llm.embed(
"prompt": req_data.prompt, {
"multi_modal_data": mm_data, "prompt": req_data.prompt,
}) "multi_modal_data": mm_data,
}
)
print("-" * 50) print("-" * 50)
for output in outputs: for output in outputs:
@ -164,23 +168,30 @@ model_example_map = {
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'vision language models for multimodal embedding') "vision language models for multimodal embedding"
parser.add_argument('--model-name', )
'-m', parser.add_argument(
type=str, "--model-name",
default="vlm2vec", "-m",
choices=model_example_map.keys(), type=str,
help='The name of the embedding model.') default="vlm2vec",
parser.add_argument('--modality', choices=model_example_map.keys(),
type=str, help="The name of the embedding model.",
default="image", )
choices=get_args(QueryModality), parser.add_argument(
help='Modality of the input.') "--modality",
parser.add_argument("--seed", type=str,
type=int, default="image",
default=None, choices=get_args(QueryModality),
help="Set the seed when initializing `vllm.LLM`.") help="Modality of the input.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args() return parser.parse_args()
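As a reference for the call pattern above, a minimal end-to-end sketch of the embedding flow follows. It assumes the `TIGER-Lab/VLM2Vec-Full` checkpoint (a guess at what the default `vlm2vec` entry points to) and the bundled `cherry_blossom` image asset; the engine arguments and prompt string are illustrative stand-ins, not the exact values the example computes.

```python
from vllm import LLM
from vllm.assets.image import ImageAsset

# Engine arguments are illustrative; the example derives them from ModelRequestData.
llm = LLM(
    model="TIGER-Lab/VLM2Vec-Full",  # assumed checkpoint for the "vlm2vec" entry
    task="embed",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"image": 1},
)

# Same call shape as run_encode(): a prompt plus optional multi-modal data.
outputs = llm.embed(
    {
        "prompt": "<|image_1|> Represent the given image.",
        "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
    }
)
print(len(outputs[0].outputs.embedding))
```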

View File

@ -4,6 +4,7 @@ This example shows how to use vLLM for running offline inference with
multi-image input on vision language models for text generation, multi-image input on vision language models for text generation,
using the chat template defined by the model. using the chat template defined by the model.
""" """
import os import os
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict from dataclasses import asdict
@ -59,8 +60,9 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls) placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n" prompt = (
"<|im_start|>assistant\n") f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n<|im_start|>assistant\n"
)
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return ModelRequestData( return ModelRequestData(
@ -81,23 +83,21 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -106,8 +106,7 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_deepseek_vl2(question: str, def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
image_urls: list[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-tiny" model_name = "deepseek-ai/deepseek-vl2-tiny"
engine_args = EngineArgs( engine_args = EngineArgs(
@ -118,8 +117,9 @@ def load_deepseek_vl2(question: str,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholder = "".join(f"image_{i}:<image>\n" placeholder = "".join(
for i, _ in enumerate(image_urls, start=1)) f"image_{i}:<image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:" prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
return ModelRequestData( return ModelRequestData(
@ -140,23 +140,21 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -176,15 +174,15 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4}, mm_processor_kwargs={"max_dynamic_patch": 4},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
# Stop tokens for H2OVL-Mississippi # Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m # https://huggingface.co/h2oai/h2ovl-mississippi-800m
@ -211,14 +209,13 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
# if you are running out of memory, you can reduce the "longest_edge". # if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={ mm_processor_kwargs={
"size": { "size": {"longest_edge": 2 * 364},
"longest_edge": 2 * 364
},
}, },
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501 prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -238,15 +235,16 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
enforce_eager=True, enforce_eager=True,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={ mm_processor_kwargs={
"max_image_size": { "max_image_size": {"longest_edge": 384},
"longest_edge": 384
},
}, },
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
prompt = f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501 )
prompt = (
f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
@ -265,15 +263,15 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4}, mm_processor_kwargs={"max_dynamic_patch": 4},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
# Stop tokens for InternVL # Stop tokens for InternVL
# models variants may have different stop tokens # models variants may have different stop tokens
@ -301,23 +299,21 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -338,24 +334,21 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name, processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -419,15 +412,15 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4}, mm_processor_kwargs={"max_dynamic_patch": 4},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -449,15 +442,15 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -509,8 +502,9 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"num_crops": 4}, mm_processor_kwargs={"num_crops": 4},
) )
placeholders = "\n".join(f"<|image_{i}|>" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
return ModelRequestData( return ModelRequestData(
@ -542,8 +536,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"dynamic_hd": 4}, mm_processor_kwargs={"dynamic_hd": 4},
) )
placeholders = "".join(f"<|image_{i}|>" placeholders = "".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1))
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>" prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
return ModelRequestData( return ModelRequestData(
@ -554,8 +547,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_qwen_vl_chat(question: str, def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData:
image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat" model_name = "Qwen/Qwen-VL-Chat"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
@ -565,24 +557,26 @@ def load_qwen_vl_chat(question: str,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholders = "".join(f"Picture {i}: <img></img>\n" placeholders = "".join(
for i, _ in enumerate(image_urls, start=1)) f"Picture {i}: <img></img>\n" for i, _ in enumerate(image_urls, start=1)
)
# This model does not have a chat_template attribute on its tokenizer, # This model does not have a chat_template attribute on its tokenizer,
# so we need to explicitly pass it. We use ChatML since it's used in the # so we need to explicitly pass it. We use ChatML since it's used in the
# generation utils of the model: # generation utils of the model:
# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265 # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True)
# Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501 chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(messages, prompt = tokenizer.apply_chat_template(
tokenize=False, messages,
add_generation_prompt=True, tokenize=False,
chat_template=chat_template) add_generation_prompt=True,
chat_template=chat_template,
)
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
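The stop-token ids computed here are what `run_generate()` and `run_chat()` later pass as `stop_token_ids`. A compact sketch of just that path, reusing the Qwen-VL-Chat tokenizer named in this loader:

```python
from transformers import AutoTokenizer

from vllm import SamplingParams

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_tokens]

# Same shape as the SamplingParams built in run_generate()/run_chat().
sampling_params = SamplingParams(
    temperature=0.0, max_tokens=256, stop_token_ids=stop_token_ids
)
print(sampling_params)
```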
@ -600,9 +594,11 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try: try:
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
except ModuleNotFoundError: except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not ' print(
'be automatically resized. You can enable this functionality by ' "WARNING: `qwen-vl-utils` not installed, input images will not "
'`pip install qwen-vl-utils`.') "be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None process_vision_info = None
model_name = "Qwen/Qwen2-VL-7B-Instruct" model_name = "Qwen/Qwen2-VL-7B-Instruct"
@ -616,26 +612,22 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": "system", {"role": "system", "content": "You are a helpful assistant."},
"content": "You are a helpful assistant." {
}, { "role": "user",
"role": "content": [
"user", *placeholders,
"content": [ {"type": "text", "text": question},
*placeholders, ],
{ },
"type": "text", ]
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
if process_vision_info is None: if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls] image_data = [fetch_image(url) for url in image_urls]
@ -653,9 +645,11 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try: try:
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
except ModuleNotFoundError: except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not ' print(
'be automatically resized. You can enable this functionality by ' "WARNING: `qwen-vl-utils` not installed, input images will not "
'`pip install qwen-vl-utils`.') "be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None process_vision_info = None
model_name = "Qwen/Qwen2.5-VL-3B-Instruct" model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
@ -668,32 +662,27 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": "system", {"role": "system", "content": "You are a helpful assistant."},
"content": "You are a helpful assistant." {
}, { "role": "user",
"role": "content": [
"user", *placeholders,
"content": [ {"type": "text", "text": question},
*placeholders, ],
{ },
"type": "text", ]
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
if process_vision_info is None: if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls] image_data = [fetch_image(url) for url in image_urls]
else: else:
image_data, _ = process_vision_info(messages, image_data, _ = process_vision_info(messages, return_video_kwargs=False)
return_video_kwargs=False)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
@ -726,23 +715,20 @@ model_example_map = {
} }
def run_generate(model, question: str, image_urls: list[str], def run_generate(model, question: str, image_urls: list[str], seed: Optional[int]):
seed: Optional[int]):
req_data = model_example_map[model](question, image_urls) req_data = model_example_map[model](question, image_urls)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed} engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(
max_tokens=256, temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
stop_token_ids=req_data.stop_token_ids) )
outputs = llm.generate( outputs = llm.generate(
{ {
"prompt": req_data.prompt, "prompt": req_data.prompt,
"multi_modal_data": { "multi_modal_data": {"image": req_data.image_data},
"image": req_data.image_data
},
}, },
sampling_params=sampling_params, sampling_params=sampling_params,
lora_request=req_data.lora_requests, lora_request=req_data.lora_requests,
@ -755,38 +741,40 @@ def run_generate(model, question: str, image_urls: list[str],
print("-" * 50) print("-" * 50)
def run_chat(model: str, question: str, image_urls: list[str], def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]):
seed: Optional[int]):
req_data = model_example_map[model](question, image_urls) req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": seed} engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(
max_tokens=256, temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
stop_token_ids=req_data.stop_token_ids) )
outputs = llm.chat( outputs = llm.chat(
[{ [
"role": {
"user", "role": "user",
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": question, "text": question,
},
*({
"type": "image_url",
"image_url": {
"url": image_url
}, },
} for image_url in image_urls), *(
], {
}], "type": "image_url",
"image_url": {"url": image_url},
}
for image_url in image_urls
),
],
}
],
sampling_params=sampling_params, sampling_params=sampling_params,
chat_template=req_data.chat_template, chat_template=req_data.chat_template,
lora_request=req_data.lora_requests, lora_request=req_data.lora_requests,
@ -801,32 +789,39 @@ def run_chat(model: str, question: str, image_urls: list[str],
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'vision language models that support multi-image input for text ' "vision language models that support multi-image input for text "
'generation') "generation"
parser.add_argument('--model-type', )
'-m', parser.add_argument(
type=str, "--model-type",
default="phi3_v", "-m",
choices=model_example_map.keys(), type=str,
help='Huggingface "model_type".') default="phi3_v",
parser.add_argument("--method", choices=model_example_map.keys(),
type=str, help='Huggingface "model_type".',
default="generate", )
choices=["generate", "chat"], parser.add_argument(
help="The method to run in `vllm.LLM`.") "--method",
parser.add_argument("--seed", type=str,
type=int, default="generate",
default=None, choices=["generate", "chat"],
help="Set the seed when initializing `vllm.LLM`.") help="The method to run in `vllm.LLM`.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument( parser.add_argument(
"--num-images", "--num-images",
"-n", "-n",
type=int, type=int,
choices=list(range(1, choices=list(range(1, len(IMAGE_URLS) + 1)), # the max number of images
len(IMAGE_URLS) + 1)), # the max number of images
default=2, default=2,
help="Number of images to use for the demo.") help="Number of images to use for the demo.",
)
return parser.parse_args() return parser.parse_args()
@ -835,7 +830,7 @@ def main(args: Namespace):
method = args.method method = args.method
seed = args.seed seed = args.seed
image_urls = IMAGE_URLS[:args.num_images] image_urls = IMAGE_URLS[: args.num_images]
if method == "generate": if method == "generate":
run_generate(model, QUESTION, image_urls, seed) run_generate(model, QUESTION, image_urls, seed)
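Most loaders in this file reduce to the same recipe: build one placeholder entry per image, render the chat template, then hand the prompt and fetched images to `LLM.generate`. A condensed sketch of that recipe, using the Qwen2.5-VL loader's model name and a single illustrative image URL as stand-ins for whatever `--model-type` and `--num-images` select:

```python
from dataclasses import asdict

from transformers import AutoProcessor

from vllm import LLM, EngineArgs, SamplingParams
from vllm.multimodal.utils import fetch_image

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"  # illustrative pick from model_example_map
question = "What is shown in this image?"
image_urls = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
]

engine_args = EngineArgs(
    model=model_name,
    max_model_len=32768,
    limit_mm_per_prompt={"image": len(image_urls)},
)

# One placeholder dict per image, then the text question, exactly as the loaders do.
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
    {"role": "user", "content": [*placeholders, {"type": "text", "text": question}]},
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

llm = LLM(**asdict(engine_args))
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [fetch_image(url) for url in image_urls]},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=256),
)
print(outputs[0].outputs[0].text)
```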

View File

@ -17,16 +17,15 @@ import requests
def clear_line(n: int = 1) -> None: def clear_line(n: int = 1) -> None:
LINE_UP = '\033[1A' LINE_UP = "\033[1A"
LINE_CLEAR = '\x1b[2K' LINE_CLEAR = "\x1b[2K"
for _ in range(n): for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True) print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(prompt: str, def post_http_request(
api_url: str, prompt: str, api_url: str, n: int = 1, stream: bool = False
n: int = 1, ) -> requests.Response:
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"} headers = {"User-Agent": "Test Client"}
pload = { pload = {
"prompt": prompt, "prompt": prompt,
@ -35,17 +34,14 @@ def post_http_request(prompt: str,
"max_tokens": 16, "max_tokens": 16,
"stream": stream, "stream": stream,
} }
response = requests.post(api_url, response = requests.post(api_url, headers=headers, json=pload, stream=stream)
headers=headers,
json=pload,
stream=stream)
return response return response
def get_streaming_response(response: requests.Response) -> Iterable[list[str]]: def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
for chunk in response.iter_lines(chunk_size=8192, for chunk in response.iter_lines(
decode_unicode=False, chunk_size=8192, decode_unicode=False, delimiter=b"\n"
delimiter=b"\n"): ):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"] output = data["text"]
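Put together, the two helpers above amount to a dozen lines. A minimal sketch, assuming the demo API server (`python -m vllm.entrypoints.api_server`) is listening on `localhost:8000`, which matches the script's default host and port; the prompt is arbitrary.

```python
import json

import requests

api_url = "http://localhost:8000/generate"
pload = {
    "prompt": "San Francisco is a",
    "n": 1,
    "max_tokens": 16,
    "stream": True,
}

response = requests.post(
    api_url, headers={"User-Agent": "Test Client"}, json=pload, stream=True
)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\n"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
```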

View File

@ -6,6 +6,7 @@ Note that `pip install cohere` is needed to run this example.
run: vllm serve BAAI/bge-reranker-base run: vllm serve BAAI/bge-reranker-base
""" """
from typing import Union from typing import Union
import cohere import cohere
@ -16,28 +17,28 @@ model = "BAAI/bge-reranker-base"
query = "What is the capital of France?" query = "What is the capital of France?"
documents = [ documents = [
"The capital of France is Paris", "Reranking is fun!", "The capital of France is Paris",
"vLLM is an open-source framework for fast AI serving" "Reranking is fun!",
"vLLM is an open-source framework for fast AI serving",
] ]
def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str, def cohere_rerank(
documents: list[str]) -> dict: client: Union[Client, ClientV2], model: str, query: str, documents: list[str]
) -> dict:
return client.rerank(model=model, query=query, documents=documents) return client.rerank(model=model, query=query, documents=documents)
def main(): def main():
# cohere v1 client # cohere v1 client
cohere_v1 = cohere.Client(base_url="http://localhost:8000", cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
api_key="sk-fake-key")
rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents) rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
print("-" * 50) print("-" * 50)
print("rerank_v1_result:\n", rerank_v1_result) print("rerank_v1_result:\n", rerank_v1_result)
print("-" * 50) print("-" * 50)
# or the v2 # or the v2
cohere_v2 = cohere.ClientV2("sk-fake-key", cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
base_url="http://localhost:8000")
rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents) rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
print("rerank_v2_result:\n", rerank_v2_result) print("rerank_v2_result:\n", rerank_v2_result)
print("-" * 50) print("-" * 50)

View File

@ -13,6 +13,7 @@ launch this proxy demo through:
Note: This demo will be removed once the PDController implemented in PR 15343 Note: This demo will be removed once the PDController implemented in PR 15343
(https://github.com/vllm-project/vllm/pull/15343) supports XpYd. (https://github.com/vllm-project/vllm/pull/15343) supports XpYd.
""" """
import argparse import argparse
import ipaddress import ipaddress
import itertools import itertools
@ -26,8 +27,7 @@ from typing import Callable, Optional
import aiohttp import aiohttp
import requests import requests
import uvicorn import uvicorn
from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException, from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
Request, status)
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@ -36,24 +36,24 @@ logging.basicConfig(level=logging.INFO)
class SchedulingPolicy(ABC): class SchedulingPolicy(ABC):
@abstractmethod @abstractmethod
def schedule(self, cycler: itertools.cycle): def schedule(self, cycler: itertools.cycle):
raise NotImplementedError("Scheduling Proxy is not set.") raise NotImplementedError("Scheduling Proxy is not set.")
class Proxy: class Proxy:
def __init__( def __init__(
self, self,
prefill_instances: list[str], prefill_instances: list[str],
decode_instances: list[str], decode_instances: list[str],
model: str, model: str,
scheduling_policy: SchedulingPolicy, scheduling_policy: SchedulingPolicy,
custom_create_completion: Optional[Callable[[Request], custom_create_completion: Optional[
StreamingResponse]] = None, Callable[[Request], StreamingResponse]
custom_create_chat_completion: Optional[Callable[ ] = None,
[Request], StreamingResponse]] = None, custom_create_chat_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
): ):
self.prefill_instances = prefill_instances self.prefill_instances = prefill_instances
self.decode_instances = decode_instances self.decode_instances = decode_instances
@ -68,30 +68,30 @@ class Proxy:
def setup_routes(self): def setup_routes(self):
self.router.post( self.router.post(
"/v1/completions", "/v1/completions", dependencies=[Depends(self.validate_json_request)]
dependencies=[ )(
Depends(self.validate_json_request) self.custom_create_completion
])(self.custom_create_completion if self. if self.custom_create_completion
custom_create_completion else self.create_completion) else self.create_completion
)
self.router.post( self.router.post(
"/v1/chat/completions", "/v1/chat/completions", dependencies=[Depends(self.validate_json_request)]
dependencies=[ )(
Depends(self.validate_json_request) self.custom_create_chat_completion
])(self.custom_create_chat_completion if self. if self.custom_create_chat_completion
custom_create_chat_completion else self.create_chat_completion) else self.create_chat_completion
self.router.get("/status", )
response_class=JSONResponse)(self.get_status) self.router.get("/status", response_class=JSONResponse)(self.get_status)
self.router.post("/instances/add", self.router.post(
dependencies=[Depends(self.api_key_authenticate) "/instances/add", dependencies=[Depends(self.api_key_authenticate)]
])(self.add_instance_endpoint) )(self.add_instance_endpoint)
async def validate_json_request(self, raw_request: Request): async def validate_json_request(self, raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower() content_type = raw_request.headers.get("content-type", "").lower()
if content_type != "application/json": if content_type != "application/json":
raise HTTPException( raise HTTPException(
status_code=415, status_code=415,
detail= detail="Unsupported Media Type: Only 'application/json' is allowed",
"Unsupported Media Type: Only 'application/json' is allowed",
) )
def api_key_authenticate(self, x_api_key: str = Header(...)): def api_key_authenticate(self, x_api_key: str = Header(...)):
@ -103,8 +103,7 @@ class Proxy:
detail="Server configuration error.", detail="Server configuration error.",
) )
if x_api_key != expected_api_key: if x_api_key != expected_api_key:
logger.warning("Unauthorized access attempt with API Key: %s", logger.warning("Unauthorized access attempt with API Key: %s", x_api_key)
x_api_key)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Forbidden: Invalid API Key.", detail="Forbidden: Invalid API Key.",
@ -113,8 +112,7 @@ class Proxy:
async def validate_instance(self, instance: str) -> bool: async def validate_instance(self, instance: str) -> bool:
url = f"http://{instance}/v1/models" url = f"http://{instance}/v1/models"
try: try:
async with aiohttp.ClientSession( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as client:
timeout=AIOHTTP_TIMEOUT) as client:
logger.info("Verifying %s ...", instance) logger.info("Verifying %s ...", instance)
async with client.get(url) as response: async with client.get(url) as response:
if response.status == 200: if response.status == 200:
@ -122,12 +120,15 @@ class Proxy:
if "data" in data and len(data["data"]) > 0: if "data" in data and len(data["data"]) > 0:
model_cur = data["data"][0].get("id", "") model_cur = data["data"][0].get("id", "")
if model_cur == self.model: if model_cur == self.model:
logger.info("Instance: %s could be added.", logger.info("Instance: %s could be added.", instance)
instance)
return True return True
else: else:
logger.warning("Mismatch model %s : %s != %s", logger.warning(
instance, model_cur, self.model) "Mismatch model %s : %s != %s",
instance,
model_cur,
self.model,
)
return False return False
else: else:
return False return False
@ -147,48 +148,47 @@ class Proxy:
instance_type = data.get("type") instance_type = data.get("type")
instance = data.get("instance") instance = data.get("instance")
if instance_type not in ["prefill", "decode"]: if instance_type not in ["prefill", "decode"]:
raise HTTPException(status_code=400, raise HTTPException(status_code=400, detail="Invalid instance type.")
detail="Invalid instance type.")
if not instance or ":" not in instance: if not instance or ":" not in instance:
raise HTTPException(status_code=400, raise HTTPException(status_code=400, detail="Invalid instance format.")
detail="Invalid instance format.")
host, port_str = instance.split(":") host, port_str = instance.split(":")
try: try:
if host != "localhost": if host != "localhost":
ipaddress.ip_address(host) ipaddress.ip_address(host)
port = int(port_str) port = int(port_str)
if not (0 < port < 65536): if not (0 < port < 65536):
raise HTTPException(status_code=400, raise HTTPException(status_code=400, detail="Invalid port number.")
detail="Invalid port number.")
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, raise HTTPException(
detail="Invalid instance address.") from e status_code=400, detail="Invalid instance address."
) from e
is_valid = await self.validate_instance(instance) is_valid = await self.validate_instance(instance)
if not is_valid: if not is_valid:
raise HTTPException(status_code=400, raise HTTPException(
detail="Instance validation failed.") status_code=400, detail="Instance validation failed."
)
if instance_type == "prefill": if instance_type == "prefill":
if instance not in self.prefill_instances: if instance not in self.prefill_instances:
self.prefill_instances.append(instance) self.prefill_instances.append(instance)
self.prefill_cycler = itertools.cycle( self.prefill_cycler = itertools.cycle(self.prefill_instances)
self.prefill_instances)
else: else:
raise HTTPException(status_code=400, raise HTTPException(
detail="Instance already exists.") status_code=400, detail="Instance already exists."
)
else: else:
if instance not in self.decode_instances: if instance not in self.decode_instances:
self.decode_instances.append(instance) self.decode_instances.append(instance)
self.decode_cycler = itertools.cycle(self.decode_instances) self.decode_cycler = itertools.cycle(self.decode_instances)
else: else:
raise HTTPException(status_code=400, raise HTTPException(
detail="Instance already exists.") status_code=400, detail="Instance already exists."
)
return JSONResponse(content={ return JSONResponse(
"message": content={"message": f"Added {instance} to {instance_type}_instances."}
f"Added {instance} to {instance_type}_instances." )
})
except HTTPException as http_exc: except HTTPException as http_exc:
raise http_exc raise http_exc
except Exception as e: except Exception as e:
@ -197,16 +197,16 @@ class Proxy:
async def forward_request(self, url, data, use_chunked=True): async def forward_request(self, url, data, use_chunked=True):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
try: try:
async with session.post(url=url, json=data, async with session.post(
headers=headers) as response: url=url, json=data, headers=headers
) as response:
if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501
if use_chunked: if use_chunked:
async for chunk_bytes in response.content.iter_chunked( # noqa: E501 async for chunk_bytes in response.content.iter_chunked( # noqa: E501
1024): 1024
):
yield chunk_bytes yield chunk_bytes
else: else:
content = await response.read() content = await response.read()
@ -217,20 +217,21 @@ class Proxy:
error_content = json.loads(error_content) error_content = json.loads(error_content)
except json.JSONDecodeError: except json.JSONDecodeError:
error_content = error_content error_content = error_content
logger.error("Request failed with status %s: %s", logger.error(
response.status, error_content) "Request failed with status %s: %s",
response.status,
error_content,
)
raise HTTPException( raise HTTPException(
status_code=response.status, status_code=response.status,
detail= detail=f"Request failed with status {response.status}: "
f"Request failed with status {response.status}: "
f"{error_content}", f"{error_content}",
) )
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error("ClientError occurred: %s", str(e)) logger.error("ClientError occurred: %s", str(e))
raise HTTPException( raise HTTPException(
status_code=502, status_code=502,
detail= detail="Bad Gateway: Error communicating with upstream server.",
"Bad Gateway: Error communicating with upstream server.",
) from e ) from e
except Exception as e: except Exception as e:
logger.error("Unexpected error: %s", str(e)) logger.error("Unexpected error: %s", str(e))
@ -258,8 +259,8 @@ class Proxy:
prefill_instance = self.schedule(self.prefill_cycler) prefill_instance = self.schedule(self.prefill_cycler)
try: try:
async for _ in self.forward_request( async for _ in self.forward_request(
f"http://{prefill_instance}/v1/completions", f"http://{prefill_instance}/v1/completions", kv_prepare_request
kv_prepare_request): ):
continue continue
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance) self.remove_instance_endpoint("prefill", prefill_instance)
@ -270,7 +271,8 @@ class Proxy:
try: try:
generator = self.forward_request( generator = self.forward_request(
f"http://{decode_instance}/v1/completions", request) f"http://{decode_instance}/v1/completions", request
)
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance) self.remove_instance_endpoint("decode", decode_instance)
raise http_exc raise http_exc
@ -295,8 +297,8 @@ class Proxy:
prefill_instance = self.schedule(self.prefill_cycler) prefill_instance = self.schedule(self.prefill_cycler)
try: try:
async for _ in self.forward_request( async for _ in self.forward_request(
f"http://{prefill_instance}/v1/chat/completions", f"http://{prefill_instance}/v1/chat/completions", kv_prepare_request
kv_prepare_request): ):
continue continue
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance) self.remove_instance_endpoint("prefill", prefill_instance)
@ -306,8 +308,8 @@ class Proxy:
try: try:
generator = self.forward_request( generator = self.forward_request(
"http://" + decode_instance + "/v1/chat/completions", "http://" + decode_instance + "/v1/chat/completions", request
request) )
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance) self.remove_instance_endpoint("decode", decode_instance)
raise http_exc raise http_exc
@ -318,20 +320,20 @@ class Proxy:
error_messages = [str(e) for e in exc_info if e] error_messages = [str(e) for e in exc_info if e]
print("Error occurred in disagg proxy server") print("Error occurred in disagg proxy server")
print(error_messages) print(error_messages)
return StreamingResponse(content=iter(error_messages), return StreamingResponse(
media_type="text/event-stream") content=iter(error_messages), media_type="text/event-stream"
)
def remove_instance_endpoint(self, instance_type, instance): def remove_instance_endpoint(self, instance_type, instance):
if (instance_type == "decode" and instance in self.decode_instances): if instance_type == "decode" and instance in self.decode_instances:
self.decode_instances.remove(instance) self.decode_instances.remove(instance)
self.decode_cycler = itertools.cycle(self.decode_instances) self.decode_cycler = itertools.cycle(self.decode_instances)
if (instance_type == "prefill" and instance in self.decode_instances): if instance_type == "prefill" and instance in self.decode_instances:
self.prefill_instances.remove(instance) self.prefill_instances.remove(instance)
self.prefill_cycler = itertools.cycle(self.decode_instances) self.prefill_cycler = itertools.cycle(self.decode_instances)
class RoundRobinSchedulingPolicy(SchedulingPolicy): class RoundRobinSchedulingPolicy(SchedulingPolicy):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -340,15 +342,12 @@ class RoundRobinSchedulingPolicy(SchedulingPolicy):
class ProxyServer: class ProxyServer:
def __init__( def __init__(
self, self,
args: argparse.Namespace, args: argparse.Namespace,
scheduling_policy: Optional[SchedulingPolicy] = None, scheduling_policy: Optional[SchedulingPolicy] = None,
create_completion: Optional[Callable[[Request], create_completion: Optional[Callable[[Request], StreamingResponse]] = None,
StreamingResponse]] = None, create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None,
create_chat_completion: Optional[Callable[[Request],
StreamingResponse]] = None,
): ):
self.validate_parsed_serve_args(args) self.validate_parsed_serve_args(args)
self.port = args.port self.port = args.port
@ -356,8 +355,11 @@ class ProxyServer:
prefill_instances=[] if args.prefill is None else args.prefill, prefill_instances=[] if args.prefill is None else args.prefill,
decode_instances=[] if args.decode is None else args.decode, decode_instances=[] if args.decode is None else args.decode,
model=args.model, model=args.model,
scheduling_policy=(scheduling_policy if scheduling_policy scheduling_policy=(
is not None else RoundRobinSchedulingPolicy()), scheduling_policy
if scheduling_policy is not None
else RoundRobinSchedulingPolicy()
),
custom_create_completion=create_completion, custom_create_completion=create_completion,
custom_create_chat_completion=create_chat_completion, custom_create_chat_completion=create_chat_completion,
) )
@ -382,11 +384,9 @@ class ProxyServer:
ipaddress.ip_address(host) ipaddress.ip_address(host)
port = int(port) port = int(port)
if not (0 < port < 65536): if not (0 < port < 65536):
raise ValueError( raise ValueError(f"Invalid port number in instance: {instance}")
f"Invalid port number in instance: {instance}")
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(f"Invalid instance {instance}: {str(e)}") from e
f"Invalid instance {instance}: {str(e)}") from e
def verify_model_config(self, instances: list, model: str) -> None: def verify_model_config(self, instances: list, model: str) -> None:
model_suffix = model.split("/")[-1] model_suffix = model.split("/")[-1]
@ -399,12 +399,14 @@ class ProxyServer:
if model_cur_suffix != model_suffix: if model_cur_suffix != model_suffix:
raise ValueError( raise ValueError(
f"{instance} serves a different model: " f"{instance} serves a different model: "
f"{model_cur} != {model}") f"{model_cur} != {model}"
)
else: else:
raise ValueError(f"Cannot get model id from {instance}!") raise ValueError(f"Cannot get model id from {instance}!")
except requests.RequestException as e: except requests.RequestException as e:
raise ValueError( raise ValueError(
f"Error communicating with {instance}: {str(e)}") from e f"Error communicating with {instance}: {str(e)}"
) from e
def run_server(self): def run_server(self):
app = FastAPI() app = FastAPI()
@ -417,11 +419,7 @@ class ProxyServer:
def parse_args(): def parse_args():
# Todo: allow more config # Todo: allow more config
parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") parser = argparse.ArgumentParser("vLLM disaggregated proxy server.")
parser.add_argument("--model", parser.add_argument("--model", "-m", type=str, required=True, help="Model name")
"-m",
type=str,
required=True,
help="Model name")
parser.add_argument( parser.add_argument(
"--prefill", "--prefill",

View File

@ -17,6 +17,7 @@ you can install it manually by following these steps:
2. Rename the downloaded file to: frpc_linux_amd64_v0.3 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc 3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
""" """
import argparse import argparse
import gradio as gr import gradio as gr
@ -24,16 +25,12 @@ from openai import OpenAI
def format_history_to_openai(history): def format_history_to_openai(history):
history_openai_format = [{ history_openai_format = [
"role": "system", {"role": "system", "content": "You are a great AI assistant."}
"content": "You are a great AI assistant." ]
}]
for human, assistant in history: for human, assistant in history:
history_openai_format.append({"role": "user", "content": human}) history_openai_format.append({"role": "user", "content": human})
history_openai_format.append({ history_openai_format.append({"role": "assistant", "content": assistant})
"role": "assistant",
"content": assistant
})
return history_openai_format return history_openai_format
@ -49,17 +46,17 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
temperature=temp, temperature=temp,
stream=True, stream=True,
extra_body={ extra_body={
'repetition_penalty': "repetition_penalty": 1,
1, "stop_token_ids": [int(id.strip()) for id in stop_token_ids.split(",")]
'stop_token_ids': if stop_token_ids
[int(id.strip()) else [],
for id in stop_token_ids.split(',')] if stop_token_ids else [] },
}) )
# Collect all chunks and concatenate them into a full message # Collect all chunks and concatenate them into a full message
full_message = "" full_message = ""
for chunk in stream: for chunk in stream:
full_message += (chunk.choices[0].delta.content or "") full_message += chunk.choices[0].delta.content or ""
# Return the full message as a single response # Return the full message as a single response
return full_message return full_message
@ -67,38 +64,34 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Chatbot Interface with Customizable Parameters') description="Chatbot Interface with Customizable Parameters"
parser.add_argument('--model-url', )
type=str, parser.add_argument(
default='http://localhost:8000/v1', "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
help='Model URL') )
parser.add_argument('-m', parser.add_argument(
'--model', "-m", "--model", type=str, required=True, help="Model name for the chatbot"
type=str, )
required=True, parser.add_argument(
help='Model name for the chatbot') "--temp", type=float, default=0.8, help="Temperature for text generation"
parser.add_argument('--temp', )
type=float, parser.add_argument(
default=0.8, "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
help='Temperature for text generation') )
parser.add_argument('--stop-token-ids',
type=str,
default='',
help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
return parser.parse_args() return parser.parse_args()
def build_gradio_interface(client, model_name, temp, stop_token_ids): def build_gradio_interface(client, model_name, temp, stop_token_ids):
def chat_predict(message, history): def chat_predict(message, history):
return predict(message, history, client, model_name, temp, return predict(message, history, client, model_name, temp, stop_token_ids)
stop_token_ids)
return gr.ChatInterface(fn=chat_predict, return gr.ChatInterface(
title="Chatbot Interface", fn=chat_predict,
description="A simple chatbot powered by vLLM") title="Chatbot Interface",
description="A simple chatbot powered by vLLM",
)
def main(): def main():
@ -113,12 +106,13 @@ def main():
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base) client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
# Define the Gradio chatbot interface using the predict function # Define the Gradio chatbot interface using the predict function
gradio_interface = build_gradio_interface(client, args.model, args.temp, gradio_interface = build_gradio_interface(
args.stop_token_ids) client, args.model, args.temp, args.stop_token_ids
)
gradio_interface.queue().launch(server_name=args.host, gradio_interface.queue().launch(
server_port=args.port, server_name=args.host, server_port=args.port, share=True
share=True) )
if __name__ == "__main__": if __name__ == "__main__":
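Stripped of the Gradio plumbing, `predict()` boils down to one streaming chat-completion call. A sketch against a local OpenAI-compatible vLLM server; the base URL matches the script's default `--model-url`, while the model name and sampling values are illustrative stand-ins for the CLI arguments.

```python
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",  # stand-in for args.model
    messages=[
        {"role": "system", "content": "You are a great AI assistant."},
        {"role": "user", "content": "Tell me something about vLLM."},
    ],
    temperature=0.8,
    stream=True,
    extra_body={"repetition_penalty": 1, "stop_token_ids": []},
)

full_message = ""
for chunk in stream:
    full_message += chunk.choices[0].delta.content or ""
print(full_message)
```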

View File

@ -17,6 +17,7 @@ you can install it manually by following these steps:
2. Rename the downloaded file to: frpc_linux_amd64_v0.3 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc 3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
""" """
import argparse import argparse
import json import json
@ -31,14 +32,11 @@ def http_bot(prompt):
"stream": True, "stream": True,
"max_tokens": 128, "max_tokens": 128,
} }
response = requests.post(args.model_url, response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
headers=headers,
json=pload,
stream=True)
for chunk in response.iter_lines(chunk_size=8192, for chunk in response.iter_lines(
decode_unicode=False, chunk_size=8192, decode_unicode=False, delimiter=b"\n"
delimiter=b"\n"): ):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"][0] output = data["text"][0]
@ -48,10 +46,10 @@ def http_bot(prompt):
def build_demo(): def build_demo():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("# vLLM text completion demo\n") gr.Markdown("# vLLM text completion demo\n")
inputbox = gr.Textbox(label="Input", inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
placeholder="Enter text and press ENTER") outputbox = gr.Textbox(
outputbox = gr.Textbox(label="Output", label="Output", placeholder="Generated result from the model"
placeholder="Generated result from the model") )
inputbox.submit(http_bot, [inputbox], [outputbox]) inputbox.submit(http_bot, [inputbox], [outputbox])
return demo return demo
@ -60,17 +58,15 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url", parser.add_argument(
type=str, "--model-url", type=str, default="http://localhost:8000/generate"
default="http://localhost:8000/generate") )
return parser.parse_args() return parser.parse_args()
def main(args): def main(args):
demo = build_demo() demo = build_demo()
demo.queue().launch(server_name=args.host, demo.queue().launch(server_name=args.host, server_port=args.port, share=True)
server_port=args.port,
share=True)
if __name__ == "__main__": if __name__ == "__main__":
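The Blocks wiring in `build_demo()` can be exercised without a model at all. A self-contained sketch that swaps `http_bot` for a local stand-in function while keeping the same Textbox/submit plumbing; the stand-in and the `localhost` server name are illustrative.

```python
import gradio as gr


def fake_bot(prompt: str) -> str:
    # Stand-in for http_bot(), which streams from the vLLM /generate endpoint.
    return f"Echo from the demo: {prompt}"


with gr.Blocks() as demo:
    gr.Markdown("# vLLM text completion demo\n")
    inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
    outputbox = gr.Textbox(
        label="Output", placeholder="Generated result from the model"
    )
    inputbox.submit(fake_bot, [inputbox], [outputbox])

demo.queue().launch(server_name="localhost", server_port=8001, share=False)
```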

View File

@ -5,6 +5,7 @@ Jina and Cohere https://jina.ai/reranker
run: vllm serve BAAI/bge-reranker-base run: vllm serve BAAI/bge-reranker-base
""" """
import json import json
import requests import requests
@ -14,14 +15,13 @@ url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"} headers = {"accept": "application/json", "Content-Type": "application/json"}
data = { data = {
"model": "model": "BAAI/bge-reranker-base",
"BAAI/bge-reranker-base", "query": "What is the capital of France?",
"query":
"What is the capital of France?",
"documents": [ "documents": [
"The capital of Brazil is Brasilia.", "The capital of Brazil is Brasilia.",
"The capital of France is Paris.", "Horses and cows are both animals" "The capital of France is Paris.",
] "Horses and cows are both animals",
],
} }
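Continuing from the payload above, the call itself is a single `requests.post`. A self-contained sketch, assuming `vllm serve BAAI/bge-reranker-base` is running locally as the docstring notes:

```python
import json

import requests

url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals",
    ],
}

response = requests.post(url, headers=headers, json=data)
print(json.dumps(response.json(), indent=2))
```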

View File

@ -9,17 +9,14 @@ from msgspec.msgpack import Decoder
# #
# Types copied from vllm.distributed.kv_events # Types copied from vllm.distributed.kv_events
# #
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
gc=False):
ts: float ts: float
events: list[Any] events: list[Any]
class KVCacheEvent(msgspec.Struct, class KVCacheEvent(
array_like=True, msgspec.Struct, array_like=True, omit_defaults=True, gc=False, tag=True
omit_defaults=True, ):
gc=False,
tag=True):
"""Base class for all KV cache-related events""" """Base class for all KV cache-related events"""
@ -77,8 +74,9 @@ def main():
if last_seq >= 0 and seq > last_seq + 1: if last_seq >= 0 and seq > last_seq + 1:
missed = seq - last_seq - 1 missed = seq - last_seq - 1
print(f"Missed {missed} messages" print(
f" (last: {last_seq}, current: {seq})") f"Missed {missed} messages (last: {last_seq}, current: {seq})"
)
replay.send((last_seq + 1).to_bytes(8, "big")) replay.send((last_seq + 1).to_bytes(8, "big"))
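The decoding side of the subscriber is just a typed msgspec round trip. A self-contained sketch that encodes a batch the way a publisher would and decodes it with the same `Decoder`; the event payload here is made up for illustration.

```python
from typing import Any

import msgspec
from msgspec.msgpack import Decoder


# Same shape as the EventBatch struct copied from vllm.distributed.kv_events.
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    ts: float
    events: list[Any]


encoder = msgspec.msgpack.Encoder()
encoded = encoder.encode(EventBatch(ts=0.0, events=[["made-up-event", 1]]))
batch = Decoder(EventBatch).decode(encoded)
print(batch.ts, batch.events)
```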

View File

@ -12,26 +12,22 @@ from openai import OpenAI
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
messages = [{ messages = [
"role": "system", {"role": "system", "content": "You are a helpful assistant."},
"content": "You are a helpful assistant." {"role": "user", "content": "Who won the world series in 2020?"},
}, { {
"role": "user", "role": "assistant",
"content": "Who won the world series in 2020?" "content": "The Los Angeles Dodgers won the World Series in 2020.",
}, { },
"role": "assistant", {"role": "user", "content": "Where was it played?"},
"content": "The Los Angeles Dodgers won the World Series in 2020." ]
}, {
"role": "user",
"content": "Where was it played?"
}]
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server") parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument("--stream", parser.add_argument(
action="store_true", "--stream", action="store_true", help="Enable streaming response"
help="Enable streaming response") )
return parser.parse_args() return parser.parse_args()


@ -43,7 +43,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
with requests.get(content_url) as response: with requests.get(content_url) as response:
response.raise_for_status() response.raise_for_status()
result = base64.b64encode(response.content).decode('utf-8') result = base64.b64encode(response.content).decode("utf-8")
return result return result
@ -51,10 +51,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
# Text-only inference # Text-only inference
def run_text_only(model: str) -> None: def run_text_only(model: str) -> None:
chat_completion = client.chat.completions.create( chat_completion = client.chat.completions.create(
messages=[{ messages=[{"role": "user", "content": "What's the capital of France?"}],
"role": "user",
"content": "What's the capital of France?"
}],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -65,26 +62,21 @@ def run_text_only(model: str) -> None:
# Single-image input inference # Single-image input inference
def run_single_image(model: str) -> None: def run_single_image(model: str) -> None:
## Use image url in the payload ## Use image url in the payload
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this image?"},
"type": "text", {
"text": "What's in this image?" "type": "image_url",
}, "image_url": {"url": image_url},
{
"type": "image_url",
"image_url": {
"url": image_url
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -95,22 +87,18 @@ def run_single_image(model: str) -> None:
## Use base64 encoded image in the payload ## Use base64 encoded image in the payload
image_base64 = encode_base64_content_from_url(image_url) image_base64 = encode_base64_content_from_url(image_url)
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this image?"},
"type": "text", {
"text": "What's in this image?" "type": "image_url",
}, "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -124,28 +112,22 @@ def run_multi_image(model: str) -> None:
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What are the animals in these images?"},
"type": "text", {
"text": "What are the animals in these images?" "type": "image_url",
}, "image_url": {"url": image_url_duck},
{
"type": "image_url",
"image_url": {
"url": image_url_duck
}, },
}, {
{ "type": "image_url",
"type": "image_url", "image_url": {"url": image_url_lion},
"image_url": {
"url": image_url_lion
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -161,22 +143,18 @@ def run_video(model: str) -> None:
## Use video url in the payload ## Use video url in the payload
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this video?"},
"type": "text", {
"text": "What's in this video?" "type": "video_url",
}, "video_url": {"url": video_url},
{
"type": "video_url",
"video_url": {
"url": video_url
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -186,22 +164,18 @@ def run_video(model: str) -> None:
## Use base64 encoded video in the payload ## Use base64 encoded video in the payload
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this video?"},
"type": "text", {
"text": "What's in this video?" "type": "video_url",
}, "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
{
"type": "video_url",
"video_url": {
"url": f"data:video/mp4;base64,{video_base64}"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -219,24 +193,22 @@ def run_audio(model: str) -> None:
# OpenAI-compatible schema (`input_audio`) # OpenAI-compatible schema (`input_audio`)
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this audio?"},
"type": "text", {
"text": "What's in this audio?" "type": "input_audio",
}, "input_audio": {
{ # Any format supported by librosa is supported
"type": "input_audio", "data": audio_base64,
"input_audio": { "format": "wav",
# Any format supported by librosa is supported },
"data": audio_base64,
"format": "wav"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -246,23 +218,21 @@ def run_audio(model: str) -> None:
# HTTP URL # HTTP URL
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this audio?"},
"type": "text", {
"text": "What's in this audio?" "type": "audio_url",
}, "audio_url": {
{ # Any format supported by librosa is supported
"type": "audio_url", "url": audio_url
"audio_url": { },
# Any format supported by librosa is supported
"url": audio_url
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -272,23 +242,21 @@ def run_audio(model: str) -> None:
# base64 URL # base64 URL
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this audio?"},
"type": "text", {
"text": "What's in this audio?" "type": "audio_url",
}, "audio_url": {
{ # Any format supported by librosa is supported
"type": "audio_url", "url": f"data:audio/ogg;base64,{audio_base64}"
"audio_url": { },
# Any format supported by librosa is supported
"url": f"data:audio/ogg;base64,{audio_base64}"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
@ -308,14 +276,17 @@ example_function_map = {
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using OpenAI client for online serving with ' description="Demo on using OpenAI client for online serving with "
'multimodal language models served with vLLM.') "multimodal language models served with vLLM."
parser.add_argument('--chat-type', )
'-c', parser.add_argument(
type=str, "--chat-type",
default="single-image", "-c",
choices=list(example_function_map.keys()), type=str,
help='Conversation type with multimodal data.') default="single-image",
choices=list(example_function_map.keys()),
help="Conversation type with multimodal data.",
)
return parser.parse_args() return parser.parse_args()


@ -16,6 +16,7 @@ vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \
--chat-template examples/tool_chat_template_hermes.jinja \ --chat-template examples/tool_chat_template_hermes.jinja \
--enable-auto-tool-choice --tool-call-parser hermes --enable-auto-tool-choice --tool-call-parser hermes
""" """
import json import json
from typing import Any from typing import Any
@ -25,55 +26,55 @@ from openai import OpenAI
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
tools = [{ properties = {
"type": "function", "city": {
"function": { "type": "string",
"name": "get_current_weather", "description": "The city to find the weather for, e.g. 'San Francisco'",
"description": "Get the current weather in a given location", },
"parameters": { "state": {
"type": "object", "type": "string",
"properties": { "description": "the two-letter abbreviation for the state that the city is"
"city": { " in, e.g. 'CA' which would mean 'California'",
"type": },
"string", "unit": {
"description": "type": "string",
"The city to find the weather for, e.g. 'San Francisco'" "description": "The unit to fetch the temperature in",
}, "enum": ["celsius", "fahrenheit"],
"state": { },
"type": }
"string",
"description": tools = [
"the two-letter abbreviation for the state that the city is" {
" in, e.g. 'CA' which would mean 'California'" "type": "function",
}, "function": {
"unit": { "name": "get_current_weather",
"type": "string", "description": "Get the current weather in a given location",
"description": "The unit to fetch the temperature in", "parameters": {
"enum": ["celsius", "fahrenheit"] "type": "object",
} "properties": properties,
"required": ["city", "state", "unit"],
}, },
"required": ["city", "state", "unit"] },
}
} }
}] ]
messages = [{ messages = [
"role": "user", {"role": "user", "content": "Hi! How are you doing today?"},
"content": "Hi! How are you doing today?" {"role": "assistant", "content": "I'm doing well! How can I help you?"},
}, { {
"role": "assistant", "role": "user",
"content": "I'm doing well! How can I help you?" "content": (
}, { "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
"role": ),
"user", },
"content": ]
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
def get_current_weather(city: str, state: str, unit: 'str'): def get_current_weather(city: str, state: str, unit: "str"):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " return (
"partly cloudly, with highs in the 90's.") "The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's."
)
def handle_tool_calls_stream( def handle_tool_calls_stream(
@ -82,10 +83,9 @@ def handle_tool_calls_stream(
model: str, model: str,
tools: list[dict[str, Any]], tools: list[dict[str, Any]],
) -> list[Any]: ) -> list[Any]:
tool_calls_stream = client.chat.completions.create(messages=messages, tool_calls_stream = client.chat.completions.create(
model=model, messages=messages, model=model, tools=tools, stream=True
tools=tools, )
stream=True)
chunks = [] chunks = []
print("chunks: ") print("chunks: ")
for chunk in tool_calls_stream: for chunk in tool_calls_stream:
@ -106,8 +106,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
tool_call = chunk.choices[0].delta.tool_calls[0] tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx: if tool_call.index != tool_call_idx:
if tool_call_idx >= 0: if tool_call_idx >= 0:
print(f"streamed tool call arguments: " print(f"streamed tool call arguments: {arguments[tool_call_idx]}")
f"{arguments[tool_call_idx]}")
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("") arguments.append("")
if tool_call.id: if tool_call.id:
@ -115,8 +114,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
if tool_call.function: if tool_call.function:
if tool_call.function.name: if tool_call.function.name:
print( print(f"streamed tool call name: {tool_call.function.name}")
f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments: if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments arguments[tool_call_idx] += tool_call.function.arguments
@ -136,9 +134,9 @@ def main():
models = client.models.list() models = client.models.list()
model = models.data[0].id model = models.data[0].id
chat_completion = client.chat.completions.create(messages=messages, chat_completion = client.chat.completions.create(
model=model, messages=messages, model=model, tools=tools
tools=tools) )
print("-" * 70) print("-" * 70)
print("Chat completion results:") print("Chat completion results:")
@ -158,10 +156,12 @@ def main():
print("-" * 70) print("-" * 70)
# Add tool call results to the conversation # Add tool call results to the conversation
messages.append({ messages.append(
"role": "assistant", {
"tool_calls": chat_completion.choices[0].message.tool_calls "role": "assistant",
}) "tool_calls": chat_completion.choices[0].message.tool_calls,
}
)
# Now, simulate a tool call # Now, simulate a tool call
available_tools = {"get_current_weather": get_current_weather} available_tools = {"get_current_weather": get_current_weather}
@ -172,17 +172,18 @@ def main():
args = json.loads(call.function.arguments) args = json.loads(call.function.arguments)
result = tool_to_call(**args) result = tool_to_call(**args)
print("tool_to_call result: ", result) print("tool_to_call result: ", result)
messages.append({ messages.append(
"role": "tool", {
"content": result, "role": "tool",
"tool_call_id": call.id, "content": result,
"name": call.function.name "tool_call_id": call.id,
}) "name": call.function.name,
}
)
chat_completion_2 = client.chat.completions.create(messages=messages, chat_completion_2 = client.chat.completions.create(
model=model, messages=messages, model=model, tools=tools, stream=False
tools=tools, )
stream=False)
print("Chat completion2 results:") print("Chat completion2 results:")
print(chat_completion_2) print(chat_completion_2)
print("-" * 70) print("-" * 70)


@ -28,18 +28,16 @@ tools = [
"type": "object", "type": "object",
"properties": { "properties": {
"city": { "city": {
"type": "type": "string",
"string", "description": "The city to find the weather for"
"description":
"The city to find the weather for"
", e.g. 'San Francisco'", ", e.g. 'San Francisco'",
}, },
"state": { "state": {
"type": "type": "string",
"string", "description": (
"description": "the two-letter abbreviation for the state that the "
"the two-letter abbreviation for the state that the " "city is in, e.g. 'CA' which would mean 'California'"
"city is in, e.g. 'CA' which would mean 'California'", ),
}, },
"unit": { "unit": {
"type": "string", "type": "string",
@ -60,22 +58,20 @@ tools = [
"type": "object", "type": "object",
"properties": { "properties": {
"city": { "city": {
"type": "type": "string",
"string", "description": (
"description": "The city to get the forecast for, e.g. 'New York'"
"The city to get the forecast for, e.g. 'New York'", ),
}, },
"state": { "state": {
"type": "type": "string",
"string", "description": (
"description": "The two-letter abbreviation for the state, e.g. 'NY'"
"The two-letter abbreviation for the state, e.g. 'NY'", ),
}, },
"days": { "days": {
"type": "type": "integer",
"integer", "description": "Number of days to get the forecast for (1-7)",
"description":
"Number of days to get the forecast for (1-7)",
}, },
"unit": { "unit": {
"type": "string", "type": "string",
@ -90,19 +86,11 @@ tools = [
] ]
messages = [ messages = [
{"role": "user", "content": "Hi! How are you doing today?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{ {
"role": "user", "role": "user",
"content": "Hi! How are you doing today?" "content": "Can you tell me what the current weather is in Dallas \
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Dallas \
and the forecast for the next 5 days, in fahrenheit?", and the forecast for the next 5 days, in fahrenheit?",
}, },
] ]
@ -123,17 +111,16 @@ def main():
model=model, model=model,
tools=tools, tools=tools,
tool_choice="required", tool_choice="required",
stream=True # Enable streaming response stream=True, # Enable streaming response
) )
for chunk in chat_completion: for chunk in chat_completion:
if chunk.choices and chunk.choices[0].delta.tool_calls: if chunk.choices and chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls) print(chunk.choices[0].delta.tool_calls)
chat_completion = client.chat.completions.create(messages=messages, chat_completion = client.chat.completions.create(
model=model, messages=messages, model=model, tools=tools, tool_choice="required"
tools=tools, )
tool_choice="required")
print(chat_completion.choices[0].message.tool_calls) print(chat_completion.choices[0].message.tool_calls)


@ -20,10 +20,9 @@ openai_api_base = "http://localhost:8000/v1"
def guided_choice_completion(client: OpenAI, model: str): def guided_choice_completion(client: OpenAI, model: str):
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
"content": "Classify this sentiment: vLLM is wonderful!" ],
}],
extra_body={"guided_choice": ["positive", "negative"]}, extra_body={"guided_choice": ["positive", "negative"]},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
@ -31,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str):
# Guided decoding by Regex # Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str): def guided_regex_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma." prompt = (
"End in .com and new line. Example result:" "Generate an email address for Alan Turing, who works in Enigma."
"alan.turing@enigma.com\n") "End in .com and new line. Example result:"
"alan.turing@enigma.com\n"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
extra_body={ }
"guided_regex": r"\w+@\w+\.com\n", ],
"stop": ["\n"] extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]},
},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
@ -66,14 +66,18 @@ class CarDescription(BaseModel):
def guided_json_completion(client: OpenAI, model: str): def guided_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of" prompt = (
"the most iconic car from the 90's") "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
@ -95,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
number ::= "1 " | "2 " number ::= "1 " | "2 "
""" """
prompt = ("Generate an SQL query to show the 'username' and 'email'" prompt = (
"from the 'users' table.") "Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_grammar": simplified_sql_grammar}, extra_body={"guided_grammar": simplified_sql_grammar},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
@ -110,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str):
# Extra backend options # Extra backend options
def extra_backend_options_completion(client: OpenAI, model: str): def extra_backend_options_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma." prompt = (
"End in .com and new line. Example result:" "Generate an email address for Alan Turing, who works in Enigma."
"alan.turing@enigma.com\n") "End in .com and new line. Example result:"
"alan.turing@enigma.com\n"
)
try: try:
# The guided_decoding_disable_fallback option forces vLLM to use # The guided_decoding_disable_fallback option forces vLLM to use
# xgrammar, so when it fails you get a 400 with the reason why # xgrammar, so when it fails you get a 400 with the reason why
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={ extra_body={
"guided_regex": r"\w+@\w+\.com\n", "guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"], "stop": ["\n"],


@ -17,11 +17,10 @@ def main():
api_key=openai_api_key, api_key=openai_api_key,
) )
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": "content": """
"""
You have access to the following function to retrieve the weather in a city: You have access to the following function to retrieve the weather in a city:
{ {
@ -58,29 +57,28 @@ You are a helpful assistant.
Given the previous instructions, what is the weather in New York City, Boston, Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco? and San Francisco?
""" """,
}] }
]
response = client.chat.completions.create( response = client.chat.completions.create(
model=client.models.list().data[0].id, model=client.models.list().data[0].id,
messages=messages, messages=messages,
response_format={ response_format={
"type": "type": "structural_tag",
"structural_tag", "structures": [
"structures": [{ {
"begin": "<function=get_weather>", "begin": "<function=get_weather>",
"schema": { "schema": {
"type": "object", "type": "object",
"properties": { "properties": {"city": {"type": "string"}},
"city": { },
"type": "string" "end": "</function>",
} }
} ],
}, "triggers": ["<function="],
"end": "</function>" },
}], )
"triggers": ["<function="]
})
print(response) print(response)
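Not part of the commit: a minimal sketch of how the tagged tool calls could be extracted from `response.choices[0].message.content` once the model follows the structural-tag format requested above. The sample `content` string is invented for illustration; only `get_weather` and the `city` field come from the schema in the file.

```python
import json
import re

# Hypothetical model output following the structural-tag format above.
content = (
    'I will check the weather. <function=get_weather>{"city": "New York City"}'
    '</function> <function=get_weather>{"city": "Boston"}</function>'
)

# Extract each tagged call and parse its JSON arguments.
pattern = r"<function=(\w+)>(.*?)</function>"
for name, raw_args in re.findall(pattern, content, flags=re.DOTALL):
    args = json.loads(raw_args)
    print(f"call {name} with {args}")
```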


@ -27,21 +27,22 @@ openai_api_base = "http://localhost:8000/v1"
def print_completion_details(completion): def print_completion_details(completion):
print("reasoning_content: ", print("reasoning_content: ", completion.choices[0].message.reasoning_content)
completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content) print("content: ", completion.choices[0].message.content)
# Guided decoding by Regex # Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str): def guided_regex_completion(client: OpenAI, model: str):
prompt = ("What is the capital of France?") prompt = "What is the capital of France?"
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={ extra_body={
"guided_regex": "(Paris|London)", "guided_regex": "(Paris|London)",
}, },
@ -57,13 +58,15 @@ class People(BaseModel):
def guided_json_completion(client: OpenAI, model: str): def guided_json_completion(client: OpenAI, model: str):
json_schema = People.model_json_schema() json_schema = People.model_json_schema()
prompt = ("Generate a JSON with the name and age of one random person.") prompt = "Generate a JSON with the name and age of one random person."
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
print_completion_details(completion) print_completion_details(completion)
@ -86,14 +89,18 @@ class CarDescription(BaseModel):
def guided_car_json_completion(client: OpenAI, model: str): def guided_car_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of" prompt = (
"the most iconic car from the 90's") "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
print_completion_details(completion) print_completion_details(completion)
@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
""" """
# This may be very slow https://github.com/vllm-project/vllm/issues/12122 # This may be very slow https://github.com/vllm-project/vllm/issues/12122
prompt = ("Generate an SQL query to show the 'username' and 'email'" prompt = (
"from the 'users' table.") "Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_grammar": simplified_sql_grammar}, extra_body={"guided_grammar": simplified_sql_grammar},
) )
print_completion_details(completion) print_completion_details(completion)


@ -20,9 +20,11 @@ from openai import OpenAI
# Now, simulate a tool call # Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: 'str'): def get_current_weather(city: str, state: str, unit: "str"):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " return (
"partly cloudly, with highs in the 90's.") "The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's."
)
available_tools = {"get_current_weather": get_current_weather} available_tools = {"get_current_weather": get_current_weather}
@ -31,49 +33,47 @@ available_tools = {"get_current_weather": get_current_weather}
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
tools = [{ properties = {
"type": "function", "city": {
"function": { "type": "string",
"name": "get_current_weather", "description": "The city to find the weather for, e.g. 'San Francisco'",
"description": "Get the current weather in a given location", },
"parameters": { "state": {
"type": "object", "type": "string",
"properties": { "description": "the two-letter abbreviation for the state that the city is"
"city": { " in, e.g. 'CA' which would mean 'California'",
"type": },
"string", "unit": {
"description": "type": "string",
"The city to find the weather for, e.g. 'San Francisco'" "description": "The unit to fetch the temperature in",
}, "enum": ["celsius", "fahrenheit"],
"state": { },
"type": }
"string",
"description": tools = [
"the two-letter abbreviation for the state that the city is" {
" in, e.g. 'CA' which would mean 'California'" "type": "function",
}, "function": {
"unit": { "name": "get_current_weather",
"type": "string", "description": "Get the current weather in a given location",
"description": "The unit to fetch the temperature in", "parameters": {
"enum": ["celsius", "fahrenheit"] "type": "object",
} "properties": properties,
"required": ["city", "state", "unit"],
}, },
"required": ["city", "state", "unit"] },
}
} }
}] ]
messages = [{ messages = [
"role": "user", {"role": "user", "content": "Hi! How are you doing today?"},
"content": "Hi! How are you doing today?" {"role": "assistant", "content": "I'm doing well! How can I help you?"},
}, { {
"role": "assistant", "role": "user",
"content": "I'm doing well! How can I help you?" "content": (
}, { "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
"role": ),
"user", },
"content": ]
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
def extract_reasoning_and_calls(chunks: list): def extract_reasoning_and_calls(chunks: list):
@ -110,73 +110,55 @@ def main():
models = client.models.list() models = client.models.list()
model = models.data[0].id model = models.data[0].id
print( print("---------Full Generate With Automatic Function Calling-------------")
"---------Full Generate With Automatic Function Calling-------------") tool_calls = client.chat.completions.create(
tool_calls = client.chat.completions.create(messages=messages, messages=messages, model=model, tools=tools
model=model,
tools=tools)
print(
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
) )
print(f"function name: " print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
f"{tool_calls.choices[0].message.tool_calls[0].function.name}") print(f"function name: {tool_calls.choices[0].message.tool_calls[0].function.name}")
print(f"function arguments: "
f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}")
print( print(
"----------Stream Generate With Automatic Function Calling-----------") f"function arguments: "
tool_calls_stream = client.chat.completions.create(messages=messages, f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}"
model=model, )
tools=tools,
stream=True) print("----------Stream Generate With Automatic Function Calling-----------")
tool_calls_stream = client.chat.completions.create(
messages=messages, model=model, tools=tools, stream=True
)
chunks = list(tool_calls_stream) chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls( reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
chunks)
print(f"reasoning_content: {reasoning_content}") print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}") print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}") print(f"function arguments: {arguments[0]}")
print( print("----------Full Generate With Named Function Calling-----------------")
"----------Full Generate With Named Function Calling-----------------") tool_calls = client.chat.completions.create(
tool_calls = client.chat.completions.create(messages=messages, messages=messages,
model=model, model=model,
tools=tools, tools=tools,
tool_choice={ tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
"type": "function", )
"function": {
"name":
"get_current_weather"
}
})
tool_call = tool_calls.choices[0].message.tool_calls[0].function tool_call = tool_calls.choices[0].message.tool_calls[0].function
print( print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
)
print(f"function name: {tool_call.name}") print(f"function name: {tool_call.name}")
print(f"function arguments: {tool_call.arguments}") print(f"function arguments: {tool_call.arguments}")
print( print("----------Stream Generate With Named Function Calling--------------")
"----------Stream Generate With Named Function Calling--------------")
tool_calls_stream = client.chat.completions.create( tool_calls_stream = client.chat.completions.create(
messages=messages, messages=messages,
model=model, model=model,
tools=tools, tools=tools,
tool_choice={ tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
"type": "function", stream=True,
"function": { )
"name": "get_current_weather"
}
},
stream=True)
chunks = list(tool_calls_stream) chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls( reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
chunks)
print(f"reasoning_content: {reasoning_content}") print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}") print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}") print(f"function arguments: {arguments[0]}")


@ -45,12 +45,12 @@ def main():
# Round 2 # Round 2
messages.append({"role": "assistant", "content": content}) messages.append({"role": "assistant", "content": content})
messages.append({ messages.append(
"role": {
"user", "role": "user",
"content": "content": "How many Rs are there in the word 'strawberry'?",
"How many Rs are there in the word 'strawberry'?", }
}) )
response = client.chat.completions.create(model=model, messages=messages) response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content reasoning_content = response.choices[0].message.reasoning_content


@ -43,9 +43,7 @@ def main():
# ruff: noqa: E501 # ruff: noqa: E501
# For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}` # For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model, stream = client.chat.completions.create(model=model, messages=messages, stream=True)
messages=messages,
stream=True)
print("client: Start streaming chat completions...") print("client: Start streaming chat completions...")
printed_reasoning_content = False printed_reasoning_content = False


@ -14,26 +14,17 @@ def vlm2vec():
response = requests.post( response = requests.post(
"http://localhost:8000/v1/embeddings", "http://localhost:8000/v1/embeddings",
json={ json={
"model": "model": "TIGER-Lab/VLM2Vec-Full",
"TIGER-Lab/VLM2Vec-Full", "messages": [
"messages": [{ {
"role": "role": "user",
"user", "content": [
"content": [ {"type": "image_url", "image_url": {"url": image_url}},
{ {"type": "text", "text": "Represent the given image."},
"type": "image_url", ],
"image_url": { }
"url": image_url ],
} "encoding_format": "float",
},
{
"type": "text",
"text": "Represent the given image."
},
],
}],
"encoding_format":
"float",
}, },
) )
response.raise_for_status() response.raise_for_status()
@ -45,19 +36,20 @@ def vlm2vec():
def dse_qwen2_vl(inp: dict): def dse_qwen2_vl(inp: dict):
# Embedding an Image # Embedding an Image
if inp["type"] == "image": if inp["type"] == "image":
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": inp["image_url"], "image_url": {
} "url": inp["image_url"],
}, { },
"type": "text", },
"text": "What is shown in this image?" {"type": "text", "text": "What is shown in this image?"},
}] ],
}] }
]
# Embedding a Text Query # Embedding a Text Query
else: else:
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
@ -66,23 +58,21 @@ def dse_qwen2_vl(inp: dict):
image_placeholder = Image.new("RGB", (56, 56)) image_placeholder = Image.new("RGB", (56, 56))
image_placeholder.save(buffer, "png") image_placeholder.save(buffer, "png")
buffer.seek(0) buffer.seek(0)
image_placeholder = base64.b64encode(buffer.read()).decode('utf-8') image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": f"data:image/jpeg;base64,{image_placeholder}", "url": f"data:image/jpeg;base64,{image_placeholder}",
} },
}, },
{ {"type": "text", "text": f"Query: {inp['content']}"},
"type": "text", ],
"text": f"Query: {inp['content']}" }
}, ]
]
}]
response = requests.post( response = requests.post(
"http://localhost:8000/v1/embeddings", "http://localhost:8000/v1/embeddings",
@ -101,12 +91,15 @@ def dse_qwen2_vl(inp: dict):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
"Script to call a specified VLM through the API. Make sure to serve " "Script to call a specified VLM through the API. Make sure to serve "
"the model with --task embed before running this.") "the model with --task embed before running this."
parser.add_argument("--model", )
type=str, parser.add_argument(
choices=["vlm2vec", "dse_qwen2_vl"], "--model",
required=True, type=str,
help="Which model to call.") choices=["vlm2vec", "dse_qwen2_vl"],
required=True,
help="Which model to call.",
)
return parser.parse_args() return parser.parse_args()
@ -114,16 +107,20 @@ def main(args):
if args.model == "vlm2vec": if args.model == "vlm2vec":
vlm2vec() vlm2vec()
elif args.model == "dse_qwen2_vl": elif args.model == "dse_qwen2_vl":
dse_qwen2_vl({ dse_qwen2_vl(
"type": "image", {
"image_url": image_url, "type": "image",
}) "image_url": image_url,
dse_qwen2_vl({ }
"type": "text", )
"content": "What is the weather like today?", dse_qwen2_vl(
}) {
"type": "text",
"content": "What is the weather like today?",
}
)
if __name__ == '__main__': if __name__ == "__main__":
args = parse_args() args = parse_args()
main(args) main(args)


@ -16,9 +16,7 @@ def parse_args():
parse = argparse.ArgumentParser() parse = argparse.ArgumentParser()
parse.add_argument("--host", type=str, default="localhost") parse.add_argument("--host", type=str, default="localhost")
parse.add_argument("--port", type=int, default=8000) parse.add_argument("--port", type=int, default=8000)
parse.add_argument("--model", parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
return parse.parse_args() return parse.parse_args()


@ -11,9 +11,9 @@ openai_api_base = "http://localhost:8000/v1"
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server") parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument("--stream", parser.add_argument(
action="store_true", "--stream", action="store_true", help="Enable streaming response"
help="Enable streaming response") )
return parser.parse_args() return parser.parse_args()
@ -34,7 +34,8 @@ def main(args):
echo=False, echo=False,
n=2, n=2,
stream=args.stream, stream=args.stream,
logprobs=3) logprobs=3,
)
print("-" * 50) print("-" * 50)
print("Completion results:") print("Completion results:")


@ -4,6 +4,7 @@ Example online usage of Score API.
Run `vllm serve <model> --task score` to start up the server in vLLM. Run `vllm serve <model> --task score` to start up the server in vLLM.
""" """
import argparse import argparse
import pprint import pprint
@ -38,9 +39,7 @@ def main(args):
pprint.pprint(score_response.json()) pprint.pprint(score_response.json())
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
text_2 = [ text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url) score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 is string and text_2 is a list:") print("\nPrompt when text_1 is string and text_2 is a list:")
@ -48,12 +47,8 @@ def main(args):
print("\nScore Response:") print("\nScore Response:")
pprint.pprint(score_response.json()) pprint.pprint(score_response.json())
text_1 = [ text_1 = ["What is the capital of Brazil?", "What is the capital of France?"]
"What is the capital of Brazil?", "What is the capital of France?" text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
]
text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url) score_response = post_http_request(prompt=prompt, api_url=api_url)
print("\nPrompt when text_1 and text_2 are both lists:") print("\nPrompt when text_1 and text_2 are both lists:")


@ -21,7 +21,7 @@ def main():
# ruff: noqa: E501 # ruff: noqa: E501
input=[ input=[
"Hello my name is", "Hello my name is",
"The best thing about vLLM is that it supports many different models" "The best thing about vLLM is that it supports many different models",
], ],
model=model, model=model,
) )


@ -5,6 +5,7 @@ Example online usage of Pooling API.
Run `vllm serve <model> --task <embed|classify|reward|score>` Run `vllm serve <model> --task <embed|classify|reward|score>`
to start up the server in vLLM. to start up the server in vLLM.
""" """
import argparse import argparse
import pprint import pprint
@ -21,9 +22,7 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
type=str,
default="jason9693/Qwen2.5-1.5B-apeach")
return parser.parse_args() return parser.parse_args()
@ -42,15 +41,13 @@ def main(args):
# Input like Chat API # Input like Chat API
prompt = { prompt = {
"model": "model": model_name,
model_name, "messages": [
"messages": [{ {
"role": "user", "role": "user",
"content": [{ "content": [{"type": "text", "text": "vLLM is great!"}],
"type": "text", }
"text": "vLLM is great!" ],
}],
}]
} }
pooling_response = post_http_request(prompt=prompt, api_url=api_url) pooling_response = post_http_request(prompt=prompt, api_url=api_url)
print("Pooling Response:") print("Pooling Response:")


@ -7,8 +7,8 @@ from openai import OpenAI
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path()
winning_call = AudioAsset('winning_call').get_local_path() winning_call = AudioAsset("winning_call").get_local_path()
# Modify OpenAI's API key and API base to use vLLM's API server. # Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
@ -31,7 +31,8 @@ def sync_openai():
extra_body=dict( extra_body=dict(
seed=4419, seed=4419,
repetition_penalty=1.3, repetition_penalty=1.3,
)) ),
)
print("transcription result:", transcription.text) print("transcription result:", transcription.text)
@ -42,33 +43,30 @@ sync_openai()
async def stream_openai_response(): async def stream_openai_response():
data = { data = {
"language": "en", "language": "en",
'stream': True, "stream": True,
"model": "openai/whisper-large-v3", "model": "openai/whisper-large-v3",
} }
url = openai_api_base + "/audio/transcriptions" url = openai_api_base + "/audio/transcriptions"
headers = {"Authorization": f"Bearer {openai_api_key}"} headers = {"Authorization": f"Bearer {openai_api_key}"}
print("transcription result:", end=' ') print("transcription result:", end=" ")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
with open(str(winning_call), "rb") as f: with open(str(winning_call), "rb") as f:
async with client.stream('POST', async with client.stream(
url, "POST", url, files={"file": f}, data=data, headers=headers
files={'file': f}, ) as response:
data=data,
headers=headers) as response:
async for line in response.aiter_lines(): async for line in response.aiter_lines():
# Each line is a JSON object prefixed with 'data: ' # Each line is a JSON object prefixed with 'data: '
if line: if line:
if line.startswith('data: '): if line.startswith("data: "):
line = line[len('data: '):] line = line[len("data: ") :]
# Last chunk, stream ends # Last chunk, stream ends
if line.strip() == '[DONE]': if line.strip() == "[DONE]":
break break
# Parse the JSON response # Parse the JSON response
chunk = json.loads(line) chunk = json.loads(line)
# Extract and print the content # Extract and print the content
content = chunk['choices'][0].get('delta', content = chunk["choices"][0].get("delta", {}).get("content")
{}).get('content') print(content, end="")
print(content, end='')
# Run the asynchronous function # Run the asynchronous function


@ -1,14 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import requests import requests
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
OTLPSpanExporter)
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (BatchSpanProcessor, from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
ConsoleSpanExporter)
from opentelemetry.trace import SpanKind, set_tracer_provider from opentelemetry.trace import SpanKind, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import ( from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
TraceContextTextMapPropagator)
trace_provider = TracerProvider() trace_provider = TracerProvider()
set_tracer_provider(trace_provider) set_tracer_provider(trace_provider)
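For context (not part of this commit), a sketch of how the imports above are typically used once the tracer provider is set: open a CLIENT span and inject the W3C trace context into the headers of the request sent to the vLLM server. The tracer and span names are arbitrary placeholders.

```python
from opentelemetry import trace
from opentelemetry.trace import SpanKind
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

tracer = trace.get_tracer("example-client")  # arbitrary instrumentation name

headers: dict[str, str] = {}
with tracer.start_as_current_span("llm-request", kind=SpanKind.CLIENT):
    # Writes a 'traceparent' header into the carrier dict.
    TraceContextTextMapPropagator().inject(headers)

# `headers` can now be passed along with the HTTP request, e.g.
# requests.post(url, headers=headers, json=payload)
```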


@ -26,6 +26,7 @@ Dependencies:
- torch - torch
- openai - openai
""" """
import base64 import base64
import io import io
@ -44,17 +45,13 @@ def main():
# Transformers # Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained( transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
model_name)
# Refer to the HuggingFace repo for the correct format to use # Refer to the HuggingFace repo for the correct format to use
chat = [{ chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
"role": "user", token_ids = tokenizer.apply_chat_template(
"content": "Please tell me about the capital of France." chat, add_generation_prompt=True, return_tensors="pt"
}] )
token_ids = tokenizer.apply_chat_template(chat,
add_generation_prompt=True,
return_tensors='pt')
embedding_layer = transformers_model.get_input_embeddings() embedding_layer = transformers_model.get_input_embeddings()
prompt_embeds = embedding_layer(token_ids).squeeze(0) prompt_embeds = embedding_layer(token_ids).squeeze(0)
@ -64,7 +61,7 @@ def main():
torch.save(prompt_embeds, buffer) torch.save(prompt_embeds, buffer)
buffer.seek(0) buffer.seek(0)
binary_data = buffer.read() binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode('utf-8') encoded_embeds = base64.b64encode(binary_data).decode("utf-8")
completion = client.completions.create( completion = client.completions.create(
model=model_name, model=model_name,
@ -75,7 +72,8 @@ def main():
temperature=0.0, temperature=0.0,
# NOTE: The OpenAI client allows passing in extra JSON body via the # NOTE: The OpenAI client allows passing in extra JSON body via the
# `extra_body` argument. # `extra_body` argument.
extra_body={"prompt_embeds": encoded_embeds}) extra_body={"prompt_embeds": encoded_embeds},
)
print("-" * 30) print("-" * 30)
print(completion.choices[0].text) print(completion.choices[0].text)


@ -28,9 +28,7 @@ llm_config = LLMConfig(
}, },
# Change to the accelerator type of the node # Change to the accelerator type of the node
accelerator_type="H100", accelerator_type="H100",
runtime_env={"env_vars": { runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
"VLLM_USE_V1": "1"
}},
# Customize engine arguments as needed (e.g. vLLM engine kwargs) # Customize engine arguments as needed (e.g. vLLM engine kwargs)
engine_kwargs={ engine_kwargs={
"tensor_parallel_size": 8, "tensor_parallel_size": 8,


@ -55,7 +55,7 @@ def load_and_split_documents(config: dict[str, Any]):
Load and split documents from web URL Load and split documents from web URL
""" """
try: try:
loader = WebBaseLoader(web_paths=(config["url"], )) loader = WebBaseLoader(web_paths=(config["url"],))
docs = loader.load() docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
@ -121,64 +121,71 @@ def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate):
""" """
Set up question answering chain Set up question answering chain
""" """
return ({ return (
"context": retriever | format_docs, {
"question": RunnablePassthrough(), "context": retriever | format_docs,
} "question": RunnablePassthrough(),
| prompt }
| llm | prompt
| StrOutputParser()) | llm
| StrOutputParser()
)
def get_parser() -> argparse.ArgumentParser: def get_parser() -> argparse.ArgumentParser:
""" """
Parse command line arguments Parse command line arguments
""" """
parser = argparse.ArgumentParser(description='RAG with vLLM and langchain') parser = argparse.ArgumentParser(description="RAG with vLLM and langchain")
# Add command line arguments # Add command line arguments
parser.add_argument('--vllm-api-key',
default="EMPTY",
help='API key for vLLM compatible services')
parser.add_argument('--vllm-embedding-endpoint',
default="http://localhost:8000/v1",
help='Base URL for embedding service')
parser.add_argument('--vllm-chat-endpoint',
default="http://localhost:8001/v1",
help='Base URL for chat service')
parser.add_argument('--uri',
default="./milvus.db",
help='URI for Milvus database')
parser.add_argument( parser.add_argument(
'--url', "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
default=("https://docs.vllm.ai/en/latest/getting_started/" )
"quickstart.html"), parser.add_argument(
help='URL of the document to process') "--vllm-embedding-endpoint",
parser.add_argument('--embedding-model', default="http://localhost:8000/v1",
default="ssmits/Qwen2-7B-Instruct-embed-base", help="Base URL for embedding service",
help='Model name for embeddings') )
parser.add_argument('--chat-model', parser.add_argument(
default="qwen/Qwen1.5-0.5B-Chat", "--vllm-chat-endpoint",
help='Model name for chat') default="http://localhost:8001/v1",
parser.add_argument('-i', help="Base URL for chat service",
'--interactive', )
action='store_true', parser.add_argument("--uri", default="./milvus.db", help="URI for Milvus database")
help='Enable interactive Q&A mode') parser.add_argument(
parser.add_argument('-k', "--url",
'--top-k', default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
type=int, help="URL of the document to process",
default=3, )
help='Number of top results to retrieve') parser.add_argument(
parser.add_argument('-c', "--embedding-model",
'--chunk-size', default="ssmits/Qwen2-7B-Instruct-embed-base",
type=int, help="Model name for embeddings",
default=1000, )
help='Chunk size for document splitting') parser.add_argument(
parser.add_argument('-o', "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
'--chunk-overlap', )
type=int, parser.add_argument(
default=200, "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
help='Chunk overlap for document splitting') )
parser.add_argument(
"-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
)
parser.add_argument(
"-c",
"--chunk-size",
type=int,
default=1000,
help="Chunk size for document splitting",
)
parser.add_argument(
"-o",
"--chunk-overlap",
type=int,
default=200,
help="Chunk overlap for document splitting",
)
return parser return parser
@ -198,7 +205,7 @@ def init_config(args: Namespace):
"url": args.url, "url": args.url,
"chunk_size": args.chunk_size, "chunk_size": args.chunk_size,
"chunk_overlap": args.chunk_overlap, "chunk_overlap": args.chunk_overlap,
"top_k": args.top_k "top_k": args.top_k,
} }
@ -230,7 +237,7 @@ def main():
while True: while True:
question = input("\nPlease enter your question: ") question = input("\nPlease enter your question: ")
if question.lower() in ['q', 'quit']: if question.lower() in ["q", "quit"]:
print("\nThank you for using! Goodbye!") print("\nThank you for using! Goodbye!")
break break
@ -238,7 +245,7 @@ def main():
print(output) print(output)
else: else:
# Default single question mode # Default single question mode
question = ("How to install vLLM?") question = "How to install vLLM?"
output = qa_chain.invoke(question) output = qa_chain.invoke(question)
print("-" * 50) print("-" * 50)
print(output) print(output)


@@ -35,6 +35,7 @@ Notes:
- Default ports: 8000 (embedding), 8001 (chat)
- First run may take time to download models
"""

import argparse
from argparse import Namespace
from typing import Any

@@ -59,7 +60,7 @@ def init_config(args: Namespace):
        "db_path": args.db_path,
        "chunk_size": args.chunk_size,
        "chunk_overlap": args.chunk_overlap,
        "top_k": args.top_k,
    }

@@ -117,52 +118,58 @@ def query_document(index: VectorStoreIndex, question: str, top_k: int):
def get_parser() -> argparse.ArgumentParser:
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="RAG with vLLM and LlamaIndex")

    # Add command line arguments
    parser.add_argument(
        "--url",
        default=("https://docs.vllm.ai/en/latest/getting_started/quickstart.html"),
        help="URL of the document to process",
    )
    parser.add_argument(
        "--embedding-model",
        default="ssmits/Qwen2-7B-Instruct-embed-base",
        help="Model name for embeddings",
    )
    parser.add_argument(
        "--chat-model", default="qwen/Qwen1.5-0.5B-Chat", help="Model name for chat"
    )
    parser.add_argument(
        "--vllm-api-key", default="EMPTY", help="API key for vLLM compatible services"
    )
    parser.add_argument(
        "--embedding-endpoint",
        default="http://localhost:8000/v1",
        help="Base URL for embedding service",
    )
    parser.add_argument(
        "--chat-endpoint",
        default="http://localhost:8001/v1",
        help="Base URL for chat service",
    )
    parser.add_argument(
        "--db-path", default="./milvus_demo.db", help="Path to Milvus database"
    )
    parser.add_argument(
        "-i", "--interactive", action="store_true", help="Enable interactive Q&A mode"
    )
    parser.add_argument(
        "-c",
        "--chunk-size",
        type=int,
        default=1000,
        help="Chunk size for document splitting",
    )
    parser.add_argument(
        "-o",
        "--chunk-overlap",
        type=int,
        default=200,
        help="Chunk overlap for document splitting",
    )
    parser.add_argument(
        "-k", "--top-k", type=int, default=3, help="Number of top results to retrieve"
    )

    return parser

@@ -193,7 +200,7 @@ def main():
        question = input("\nEnter your question: ")

        # Check for exit command
        if question.lower() in ["quit", "exit", "q"]:
            print("Exiting interactive mode...")
            break


@@ -26,6 +26,7 @@ Usage:
    streamlit run streamlit_openai_chatbot_webserver.py \
        --logger.level=debug
"""

import os
from datetime import datetime

@@ -33,8 +34,8 @@ import streamlit as st
from openai import OpenAI

# Get command line arguments from environment variables
openai_api_key = os.getenv("VLLM_API_KEY", "EMPTY")
openai_api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")

# Initialize session states for managing chat sessions
if "sessions" not in st.session_state:

@@ -81,9 +82,9 @@ def get_llm_response(messages, model):
        Streaming response object or error message string
    """
    try:
        response = client.chat.completions.create(
            model=model, messages=messages, stream=True
        )
        return response
    except Exception as e:
        st.error(f"Error details: {str(e)}")

@@ -92,8 +93,9 @@ def get_llm_response(messages, model):
# Sidebar - API Settings first
st.sidebar.title("API Settings")

new_api_base = st.sidebar.text_input(
    "API Base URL:", value=st.session_state.api_base_url
)
if new_api_base != st.session_state.api_base_url:
    st.session_state.api_base_url = new_api_base
    st.rerun()

@@ -109,16 +111,20 @@ if st.sidebar.button("New Session"):
for session_id in sorted(st.session_state.sessions.keys(), reverse=True):
    # Mark the active session with a pinned button
    if session_id == st.session_state.active_session:
        st.sidebar.button(
            f"📍 {session_id}",
            key=session_id,
            type="primary",
            on_click=switch_to_chat_session,
            args=(session_id,),
        )
    else:
        st.sidebar.button(
            f"Session {session_id}",
            key=session_id,
            on_click=switch_to_chat_session,
            args=(session_id,),
        )

# Main interface
st.title("vLLM Chat Assistant")

@@ -145,18 +151,18 @@ for message in st.session_state.messages:
if prompt := st.chat_input("Type your message here..."):
    # Save user message to session
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.session_state.sessions[st.session_state.current_session] = (
        st.session_state.messages
    )

    # Display user message
    with st.chat_message("user"):
        st.write(prompt)

    # Prepare messages for llm
    messages_for_llm = [
        {"role": m["role"], "content": m["content"]} for m in st.session_state.messages
    ]

    # Generate and display llm response
    with st.chat_message("assistant"):

@@ -179,7 +185,4 @@ if prompt := st.chat_input("Type your message here..."):
            message_placeholder.markdown(full_response)

    # Save llm response to session history
    st.session_state.messages.append({"role": "assistant", "content": full_response})


@@ -16,10 +16,10 @@ def get_first_model(client: OpenAI) -> str:
            f"{client.base_url} with API key {client.api_key}. Check\n"
            "1. the server is running\n"
            "2. the server URL is correct\n"
            "3. the API key is correct"
        ) from e

    if len(models.data) == 0:
        raise RuntimeError(f"No models found on the vLLM server at {client.base_url}")

    return models.data[0].id
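For context (not part of this commit), the helper above is typically called right after constructing an OpenAI-compatible client. A minimal sketch, assuming get_first_model from the example module above is in scope and that a vLLM server is listening on the address shown:

from openai import OpenAI

# Server address and API key are assumptions; any OpenAI-compatible vLLM
# endpoint works. get_first_model is the helper shown in the hunk above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = get_first_model(client)
print(f"Using model: {model}")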


@@ -20,6 +20,7 @@ Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
Learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""

import argparse
import contextlib
import os

@@ -49,8 +50,7 @@ def setup_environment_variables(vllm_version: str):
@contextlib.contextmanager
def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
    ktc = KVTransferConfig(
        kv_connector=lmcache_connector,
        kv_role="kv_both",

@@ -97,18 +97,19 @@ def print_output(
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
    print("-" * 50)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--version",
        choices=["v0", "v1"],
        default="v1",
        help="Specify vLLM version (default: v1)",
    )
    return parser.parse_args()

@@ -125,7 +126,6 @@ def main():
    setup_environment_variables(args.version)

    with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
        # This example script runs two requests with a shared prefix.
        # Define the shared prompt and specific prompts
        shared_prompt = "Hello, how are you?" * 1000

@@ -136,9 +136,7 @@
            shared_prompt + "Tell me a very long story",
        ]

        sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

        # Print the first output
        print_output(llm, first_prompt, sampling_params, "first")


@@ -10,6 +10,7 @@ vLLM prefill node -> LMCache server -> vLLM decode node.
Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""

import os
import subprocess
import time

@@ -49,19 +50,23 @@ def run_prefill(prefill_done, prompts):
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    ktc = KVTransferConfig(
        kv_connector="LMCacheConnector",
        kv_role="kv_producer",
        kv_rank=0,
        kv_parallel_size=2,
    )

    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # memory. Reduce the value if your GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=ktc,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    # llm.generate(prompts, sampling_params)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text

@@ -79,17 +84,21 @@ def run_decode(prefill_done, prompts, timeout=1):
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig(
        kv_connector="LMCacheConnector",
        kv_role="kv_consumer",
        kv_rank=1,
        kv_parallel_size=2,
    )

    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # of memory. Reduce the value if your GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=ktc,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    print("Waiting for prefill node to finish...")
    prefill_done.wait()

@@ -105,10 +114,9 @@ def run_decode(prefill_done, prompts, timeout=1):
def run_lmcache_server(port):
    server_proc = subprocess.Popen(
        ["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
    )
    return server_proc
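For orientation (not part of this commit), a minimal driver for the prefill/decode example above might look like the following sketch. It assumes run_prefill, run_decode and run_lmcache_server are defined as in the diff, that run_prefill sets the shared event once its KV cache has been pushed to the LMCache server, and that the port and prompt values are illustrative only:

from multiprocessing import Event, Process

if __name__ == "__main__":
    # Shared-prefix prompt mirroring the style used elsewhere in these examples.
    prompts = ["Hello, how are you?" * 1000]
    prefill_done = Event()

    # Start the LMCache server first; the port is an illustrative assumption.
    lmcache_server = run_lmcache_server(8100)

    # Each vLLM instance is assumed to be pinned to its own GPU (for example via
    # CUDA_VISIBLE_DEVICES inside run_prefill/run_decode); two instances with
    # gpu_memory_utilization=0.8 will not fit on a single device.
    prefill = Process(target=run_prefill, args=(prefill_done, prompts))
    decode = Process(target=run_decode, args=(prefill_done, prompts))
    prefill.start()
    decode.start()

    # run_decode waits on prefill_done before generating, so join the prefill
    # process first, then the decoder, then stop the LMCache server.
    prefill.join()
    decode.join()
    lmcache_server.terminate()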


@@ -17,13 +17,17 @@ async def lifespan(app: FastAPI):
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize clients
    prefiller_base_url = (
        f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1"
    )
    decoder_base_url = (
        f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1"
    )

    app.state.prefill_client = httpx.AsyncClient(
        timeout=None, base_url=prefiller_base_url
    )
    app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url)

    yield

@@ -37,7 +41,6 @@ app = FastAPI(lifespan=lifespan)
class StatsCalculator:
    def __init__(self):
        self._stats = []
        self._last_log_time = time.time()

@@ -51,13 +54,18 @@ class StatsCalculator:
    def _log_stats(self):
        # Print average, median, and 99th percentile
        np_arr = np.array(self._stats)
        output_str = (
            f"\nNum requests: {len(self._stats)}"
            + "\nPrefill node TTFT stats:"
            + f"\n - Average (ms): {np.mean(np_arr)}"
            + f"\n - Median (ms): {np.median(np_arr)}"
            + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
        )
        print(
            "===============================",
            output_str,
            "===============================",
        )


stats_calculator = StatsCalculator()

@@ -82,15 +90,16 @@ app.state.prefill_client = None
app.state.decode_client = None


async def send_request_to_service(
    client: httpx.AsyncClient, endpoint: str, req_data: dict
):
    """
    Send a request to a service using a persistent client.
    """
    req_data = req_data.copy()
    req_data["max_tokens"] = 1
    if "max_completion_tokens" in req_data:
        req_data["max_completion_tokens"] = 1

    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    response = await client.post(endpoint, json=req_data, headers=headers)

@@ -98,14 +107,16 @@ async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
    return response


async def stream_service_response(
    client: httpx.AsyncClient, endpoint: str, req_data: dict
):
    """
    Asynchronously stream the response from a service using a persistent client.
    """
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    async with client.stream(
        "POST", endpoint, json=req_data, headers=headers
    ) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk

@@ -121,28 +132,28 @@ async def handle_completions(request: Request):
        req_data = await request.json()

        # Send request to prefill service, ignore the response
        await send_request_to_service(
            app.state.prefill_client, "/completions", req_data
        )

        et = time.time()
        stats_calculator.add(et - st)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(
                app.state.decode_client, "/completions", req_data
            ):
                yield chunk

        return StreamingResponse(generate_stream(), media_type="text/event-stream")

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server - completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise

@@ -158,36 +169,39 @@ async def handle_chat_completions(request: Request):
        req_data = await request.json()

        # Send request to prefill service, ignore the response
        await send_request_to_service(
            app.state.prefill_client, "/chat/completions", req_data
        )

        et = time.time()
        stats_calculator.add(et - st)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(
                app.state.decode_client, "/chat/completions", req_data
            ):
                yield chunk

        return StreamingResponse(generate_stream(), media_type="text/event-stream")

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print(
            "Error occurred in disagg prefill proxy server - chat completions endpoint"
        )
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


if __name__ == "__main__":
    global global_args
    global_args = parse_args()

    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
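As a usage illustration only (not part of this commit), a client could exercise the proxy with a streaming POST, mirroring the httpx calls above. The host, port, route prefix and model name below are assumptions, since the FastAPI route decorators fall outside the hunks shown:

import asyncio

import httpx


async def demo() -> None:
    payload = {
        "model": "mistralai/Mistral-7B-Instruct-v0.2",  # assumed served model
        "prompt": "San Francisco is a",
        "max_tokens": 64,
        "stream": True,
    }
    # Port and route are assumptions; adjust to how the proxy app is mounted.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        async with client.stream("POST", "/v1/completions", json=payload) as response:
            response.raise_for_status()
            async for chunk in response.aiter_bytes():
                print(chunk.decode(errors="ignore"), end="")


if __name__ == "__main__":
    asyncio.run(demo())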


@@ -10,6 +10,7 @@ KV cache is transferred in the following manner:
Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""

import os
import subprocess
import time

@@ -49,15 +50,16 @@ def run_store(store_done, prompts):
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")

    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # memory. Reduce the value if your GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=ktc,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:

@@ -76,15 +78,16 @@ def run_retrieve(store_done, prompts, timeout=1):
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_both")

    # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
    # of memory. Reduce the value if your GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=ktc,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    print("Waiting for KV cache store to finish...")
    store_done.wait()

@@ -100,10 +103,9 @@ def run_retrieve(store_done, prompts, timeout=1):
def run_lmcache_server(port):
    server_proc = subprocess.Popen(
        ["python", "-m", "lmcache.experimental.server", "localhost", str(port)]
    )
    return server_proc


@@ -10,8 +10,11 @@ from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerArgs,
    TensorizerConfig,
    tensorize_lora_adapter,
    tensorize_vllm_model,
)
from vllm.utils import FlexibleArgumentParser

# yapf conflicts with isort for this docstring

examples/pyproject.toml (new file, 54 lines)

@@ -0,0 +1,54 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed

[tool.ruff]
line-length = 88
exclude = [
    # External file, leaving license intact
    "examples/other/fp8/quantizer/quantize.py",
    "vllm/vllm_flash_attn/flash_attn_interface.pyi"
]

[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]

[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
    # flake8-logging-format
    "G",
]
ignore = [
    # star imports
    "F405", "F403",
    # lambda expression assignment
    "E731",
    # Loop control variable not used within loop body
    "B007",
    # f-string format
    "UP032",
    # Can remove once 3.10+ is the minimum Python version
    "UP007",
]

[tool.ruff.lint.isort]
known-first-party = ["vllm"]

[tool.ruff.format]
docstring-code-format = true
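As a rough illustration of how this local configuration gets used (the helper below is hypothetical, not part of the commit), formatting of the examples tree can be checked by shelling out to ruff, which discovers the nearest pyproject.toml for files under examples/; "ruff format --check" is the standard ruff CLI for reporting files that would be reformatted:

import subprocess
import sys


def check_examples_format(fix: bool = False) -> int:
    # Hypothetical helper: assumes ruff is on PATH and is run from the repo root,
    # so examples/pyproject.toml (line-length = 88) applies to these files.
    cmd = ["ruff", "format", "examples/"]
    if not fix:
        cmd.append("--check")  # report files that would be reformatted
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(check_examples_format(fix="--fix" in sys.argv))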


@@ -57,6 +57,7 @@ ignore_patterns = [
    ".buildkite/**",
    "benchmarks/**",
    "build/**",
    "examples/**",
]

[tool.ruff]

@@ -144,6 +145,7 @@ skip = "tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora
skip_glob = [
    ".buildkite/*",
    "benchmarks/*",
    "examples/*",
]
use_parentheses = true
skip_gitignore = true