Merge branch 'main' into wentao-optimize-startup-log-2

yewentao256 2025-10-15 12:48:30 -07:00
commit 4652ff1e02
30 changed files with 1115 additions and 433 deletions

3
.github/CODEOWNERS vendored
View File

@ -5,9 +5,7 @@
/vllm/attention @LucasWilkinson
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/model_executor/layers/fused_moe @mgoin
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/model_loader @22quinn
@ -26,7 +24,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345
# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
/vllm/v1/attention @LucasWilkinson
/vllm/v1/attention/backends/flashinfer.py @mgoin
/vllm/v1/attention/backends/triton_attn.py @tdoublep

View File

@ -2,6 +2,7 @@
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
@ -18,11 +19,22 @@ __global__ void rms_norm_kernel(
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
const scalar_t* input_row = input + blockIdx.x * input_stride;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
constexpr int VEC_SIZE = 8;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
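For reference, a minimal NumPy sketch (not part of the change) of the per-row sum-of-squares accumulation that the `vec_op`/`scalar_op` pair performs; the kernel then block-reduces this value before computing the RMS statistic:

```python
import numpy as np

def row_sum_of_squares(row: np.ndarray, vec_size: int = 8) -> float:
    """Accumulate sum(x^2) in vec_size chunks plus a scalar tail,
    mirroring vec_op/scalar_op in the kernel above."""
    n_vec = (len(row) // vec_size) * vec_size
    acc = float(np.sum(row[:n_vec].astype(np.float32) ** 2))   # vectorized body
    acc += float(np.sum(row[n_vec:].astype(np.float32) ** 2))  # scalar tail
    return acc
```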

View File

@ -10,6 +10,7 @@
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel(
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
const scalar_t* input_row = input + blockIdx.x * input_stride;
constexpr int VEC_SIZE = 8;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;

View File

@ -352,6 +352,16 @@ Supported models:
Flags: `--tool-call-parser qwen3_xml`
### Olmo 3 Models (`olmo3`)
Olmo 3 models output tool calls in a format very similar to the one expected by the `pythonic` parser (see below), with a few differences. Each tool call is a pythonic string, but parallel tool calls are newline-delimited, and the calls are wrapped in XML tags as `<function_calls>...</function_calls>`. The parser also accepts the JSON literals `true`, `false`, and `null` in addition to the pythonic `True`, `False`, and `None`. An illustrative example appears after the flags below.
Supported models:
* TODO (will be updated after Olmo 3 release)
Flags: `--tool-call-parser olmo3`
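For illustration, a response in this format (with hypothetical tool names) and the arguments the parser would extract might look like:

```python
# Hypothetical Olmo 3 output: newline-delimited pythonic calls inside XML tags.
model_output = (
    "<function_calls>get_weather(city='San Francisco', metric='celsius')\n"
    "set_reminder(message='check forecast', repeat=false, snooze=null)</function_calls>"
)
# The parser turns this into two tool calls with JSON arguments:
#   get_weather  -> {"city": "San Francisco", "metric": "celsius"}
#   set_reminder -> {"message": "check forecast", "repeat": false, "snooze": null}
```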
### Models with Pythonic Tool Calls (`pythonic`)
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.

View File

@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "San Francisco", "metric": "celsius"}',
)
MORE_TYPES_FUNCTION_OUTPUT = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=null, "
"passed_test=true, "
"aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "John Doe", '
'"age": 37, '
'"address": {"city": "San Francisco", "state": "CA"}, '
'"role": null, '
'"passed_test": true, '
'"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == model_output
assert len(tool_calls) == 0
TEST_CASES = [
pytest.param(
True,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL],
id="simple_streaming",
),
pytest.param(
False,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL],
id="simple_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming",
),
pytest.param(
False,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming_json_literals",
),
pytest.param(
False,
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming_json_literals",
),
pytest.param(
True,
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming",
),
pytest.param(
False,
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming",
),
pytest.param(
False,
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming",
),
pytest.param(
False,
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming",
),
pytest.param(
False,
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming",
),
pytest.param(
True,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_streaming",
),
pytest.param(
False,
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_nonstreaming",
),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content is None
assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls):
assert actual.type == "function"
assert actual.function == expected
def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
model_output_deltas = [
"<function_calls>get_weather(city='San",
" Francisco', metric='celsius')\n"
f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
]
reconstructor = run_tool_extraction_streaming(
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
)
assert reconstructor.other_content == ""
assert len(reconstructor.tool_calls) == 3
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
# create a mock regex that raises TimeoutError
mock_regex = MagicMock()
mock_regex.match.side_effect = TimeoutError("Regex timeout")
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
content, tool_calls = run_tool_extraction(
tool_parser, fake_problematic_input, streaming=streaming
)
# should treat as regular text when regex times out
assert content == fake_problematic_input
assert len(tool_calls) == 0
mock_regex.match.assert_called_once()

View File

@ -12,7 +12,7 @@ from vllm.entrypoints.openai.api_server import (
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import merge_async_iterators
from vllm.utils.async_utils import merge_async_iterators
MODEL_PATH = "zai-org/chatglm3-6b"
LORA_RANK = 64

View File

@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncIterator
import pytest
from vllm.utils.async_utils import merge_async_iterators
async def _mock_async_iterator(idx: int):
try:
while True:
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print(f"iterator {idx} cancelled")
@pytest.mark.asyncio
async def test_merge_async_iterators():
iterators = [_mock_async_iterator(i) for i in range(3)]
merged_iterator = merge_async_iterators(*iterators)
async def stream_output(generator: AsyncIterator[tuple[int, str]]):
async for idx, output in generator:
print(f"idx: {idx}, output: {output}")
task = asyncio.create_task(stream_output(merged_iterator))
await asyncio.sleep(0.5)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
for iterator in iterators:
try:
await asyncio.wait_for(anext(iterator), 1)
except StopAsyncIteration:
# All iterators should be cancelled and print this message.
print("Iterator was cancelled normally")
except (Exception, asyncio.CancelledError) as e:
raise AssertionError() from e

View File

@ -2,14 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import asyncio
import hashlib
import json
import os
import pickle
import socket
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import patch
@ -37,7 +35,6 @@ from vllm.utils import (
make_zmq_path,
make_zmq_socket,
memory_profiling,
merge_async_iterators,
sha256,
split_host_port,
split_zmq_path,
@ -48,39 +45,6 @@ from vllm.utils import (
from ..utils import create_new_process_for_each_test
@pytest.mark.asyncio
async def test_merge_async_iterators():
async def mock_async_iterator(idx: int):
try:
while True:
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print(f"iterator {idx} cancelled")
iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator = merge_async_iterators(*iterators)
async def stream_output(generator: AsyncIterator[tuple[int, str]]):
async for idx, output in generator:
print(f"idx: {idx}, output: {output}")
task = asyncio.create_task(stream_output(merged_iterator))
await asyncio.sleep(0.5)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
for iterator in iterators:
try:
await asyncio.wait_for(anext(iterator), 1)
except StopAsyncIteration:
# All iterators should be cancelled and print this message.
print("Iterator was cancelled normally")
except (Exception, asyncio.CancelledError) as e:
raise AssertionError() from e
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_PORT", "5678")

View File

@ -34,7 +34,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import merge_async_iterators
from vllm.utils.async_utils import merge_async_iterators
def run_vllm(

View File

@ -5,6 +5,8 @@ import hashlib
from functools import cached_property
from typing import Any, Literal, cast
from packaging.version import parse
from pydantic import field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm import version
@ -79,25 +81,43 @@ class ObservabilityConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if (
self.collect_detailed_traces is not None
and len(self.collect_detailed_traces) == 1
and "," in self.collect_detailed_traces[0]
):
self._parse_collect_detailed_traces()
@field_validator("show_hidden_metrics_for_version")
@classmethod
def _validate_show_hidden_metrics_for_version(cls, value: str | None) -> str | None:
if value is not None:
# Raises an exception if the string is not a valid version.
parse(value)
return value
from vllm.tracing import is_otel_available, otel_import_error_traceback
@field_validator("otlp_traces_endpoint")
@classmethod
def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
if value is not None:
from vllm.tracing import is_otel_available, otel_import_error_traceback
if not is_otel_available() and self.otlp_traces_endpoint is not None:
if not is_otel_available():
raise ValueError(
"OpenTelemetry is not available. Unable to configure "
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
f"installed. Original error:\n{otel_import_error_traceback}"
)
return value
@field_validator("collect_detailed_traces")
@classmethod
def _validate_collect_detailed_traces(
cls, value: list[DetailedTraceModules] | None
) -> list[DetailedTraceModules] | None:
"""Handle the legacy case where users might provide a comma-separated
string instead of a list of strings."""
if value is not None and len(value) == 1 and "," in value[0]:
value = cast(list[DetailedTraceModules], value[0].split(","))
return value
@model_validator(mode="after")
def _validate_tracing_config(self):
if self.collect_detailed_traces and not self.otlp_traces_endpoint:
raise ValueError(
"OpenTelemetry is not available. Unable to configure "
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
f"installed. Original error:\n{otel_import_error_traceback}"
"collect_detailed_traces requires `--otlp-traces-endpoint` to be set."
)
def _parse_collect_detailed_traces(self):
assert isinstance(self.collect_detailed_traces, list)
self.collect_detailed_traces = cast(
list[DetailedTraceModules], self.collect_detailed_traces[0].split(",")
)
return self
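In short, the new `collect_detailed_traces` validator keeps accepting the legacy single comma-separated string; a minimal sketch of that normalization (the module names here are just examples):

```python
# Legacy input: one comma-separated string inside the list.
value = ["model,worker"]
if value is not None and len(value) == 1 and "," in value[0]:
    value = value[0].split(",")
assert value == ["model", "worker"]  # what newer configs pass directly
```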

View File

@ -34,7 +34,8 @@ from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import as_list, merge_async_iterators
from vllm.utils import as_list
from vllm.utils.async_utils import merge_async_iterators
logger = init_logger(__name__)

View File

@ -40,6 +40,7 @@ from vllm.outputs import (
)
from vllm.pooling_params import PoolingParams
from vllm.utils import chunk_list
from vllm.utils.async_utils import merge_async_iterators
logger = init_logger(__name__)
@ -387,8 +388,6 @@ class EmbeddingMixin(OpenAIServing):
)
generators.append(generator)
from vllm.utils import merge_async_iterators
ctx.result_generator = merge_async_iterators(*generators)
return None

View File

@ -90,14 +90,13 @@ from vllm.tracing import (
log_tracing_disabled_warning,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (
from vllm.utils import is_list_of, random_uuid
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
collect_from_async_generator,
is_list_of,
make_async,
merge_async_iterators,
random_uuid,
)
from vllm.utils.func import make_async
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)

View File

@ -36,7 +36,7 @@ from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils import merge_async_iterators
from vllm.utils.async_utils import merge_async_iterators
logger = init_logger(__name__)

View File

@ -37,8 +37,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import merge_async_iterators
from vllm.utils.func import make_async
from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__)

View File

@ -18,6 +18,7 @@ from .llama_tool_parser import Llama3JsonToolParser
from .longcat_tool_parser import LongcatFlashToolParser
from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .olmo3_tool_parser import Olmo3PythonicToolParser
from .openai_tool_parser import OpenAIToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
@ -45,6 +46,7 @@ __all__ = [
"DeepSeekV31ToolParser",
"Ernie45ToolParser",
"xLAMToolParser",
"Olmo3PythonicToolParser",
"MinimaxToolParser",
"KimiK2ToolParser",
"HunyuanA13BToolParser",

View File

@ -0,0 +1,368 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
@ToolParserManager.register_module("olmo3")
class Olmo3PythonicToolParser(ToolParser):
"""
Tool call parser for Olmo 3 models that produce tool calls as
newline-separated pythonic strings.
Used when --enable-auto-tool-choice --tool-call-parser olmo3 are both set
Code copied from pythonic_tool_parser.py and updated to handle
- newline-separated pythonic tool calls.
- argument values given as null/true/false in addition to the Pythonic literals.
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
original_model_output = model_output
# Remove xml tags.
match = re.search(
r"<function_calls>(.*?)</function_calls>", model_output, re.DOTALL
)
if match:
model_output = match.group(1).strip()
# Make the newline separated function calls into a list.
model_output = ", ".join(
[line.strip() for line in model_output.splitlines() if line.strip()]
)
model_output = f"[{model_output}]"
is_tool_call_pattern = False
try:
is_tool_call_pattern = (
self.TOOL_CALL_REGEX.match(
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
)
is not None
)
except TimeoutError:
logger.warning("Regex timeout occurred when matching tool call pattern.")
logger.debug(
"Regex timeout occurred when matching user input: %s", model_output
)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=original_model_output
)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts
):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=original_model_output
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# All function calls start with the <function_calls> tag.
# But since this is streaming, we may have seen only part of the tag.
if not current_text.startswith("<"):
return DeltaMessage(content=delta_text)
try:
# Remove xml tags.
if current_text.startswith("<function_calls>"):
current_text = current_text[len("<function_calls>") :]
if current_text.endswith("</function_calls>"):
current_text = current_text[: -len("</function_calls>")]
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
# Make the newline separated function calls into a list.
valid_text = ", ".join(
[line.strip() for line in valid_text.splitlines() if line.strip()]
)
valid_text = f"[{valid_text}]"
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a sequence of newline-separated calls"
)
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = index < len(tool_calls) - 1 or ")" not in added_text
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = added_text[:-1] if not new_call_complete else ""
if not new_call_complete and added_text[-1] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
if delta is not None:
tool_deltas.append(delta)
if (
delta.function is not None
and delta.function.arguments is not None
):
self.streamed_args_for_tool[index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content="")
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
# The model may return function calls where the values are null/true/false
# because the system prompt has API description in json.
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
if val.id == "null":
return None
elif val.id == "true":
return True
elif val.id == "false":
return False
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)
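A small sketch of the newline-to-list normalization that `extract_tool_calls` performs before handing the text to `ast.parse` (the sample output here is hypothetical):

```python
import ast

raw = "get_weather(city='SF')\nget_time(timezone=null)"
normalized = "[" + ", ".join(
    line.strip() for line in raw.splitlines() if line.strip()
) + "]"
# normalized == "[get_weather(city='SF'), get_time(timezone=null)]"
module = ast.parse(normalized)  # each element is an ast.Call handled by _handle_single_tool
```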

View File

@ -17,7 +17,7 @@ from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@dataclass(frozen=True)

View File

@ -191,6 +191,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: str | None = None
VLLM_NVFP4_GEMM_BACKEND: str | None = None
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
@ -1292,11 +1293,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
"VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
# Supported options:
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
# - <none>: automatically pick an available backend
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND",
None,
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
@ -1492,7 +1497,6 @@ def compute_hash() -> str:
"VLLM_DISABLED_KERNELS",
"VLLM_USE_DEEP_GEMM",
"VLLM_USE_DEEP_GEMM_E8M0",
"VLLM_USE_TRTLLM_FP4_GEMM",
"VLLM_USE_FUSED_MOE_GROUPED_TOPK",
"VLLM_USE_FLASHINFER_MOE_FP16",
"VLLM_USE_FLASHINFER_MOE_FP8",
@ -1524,6 +1528,7 @@ def compute_hash() -> str:
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
"VLLM_NVFP4_GEMM_BACKEND",
"VLLM_USE_FBGEMM",
]
for key in environment_variables_to_hash:
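Stepping back from the hash list: the new `VLLM_NVFP4_GEMM_BACKEND` variable added above replaces `VLLM_USE_TRTLLM_FP4_GEMM`; a hedged example of pinning it explicitly (leaving it unset lets vLLM auto-select):

```python
import os

# Pin the NVFP4 GEMM backend explicitly.
os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "flashinfer-cutlass"
# Accepted values (per env_with_choices above):
#   "flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"
```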

View File

@ -17,7 +17,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask
from vllm.utils.func import make_async
from vllm.utils.async_utils import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.v1.worker.worker_base import WorkerBase

View File

@ -20,12 +20,11 @@ from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (
_run_task_with_lock,
get_distributed_init_method,
get_ip,
get_open_port,
)
from vllm.utils.func import make_async
from vllm.utils.async_utils import make_async
from vllm.v1.outputs import SamplerOutput
if ray is not None:
@ -748,3 +747,9 @@ class RayDistributedExecutor(DistributedExecutorBase):
# Assume that the Ray workers are healthy.
# TODO: check the health of the Ray workers
return
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
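For context, `_run_task_with_lock` simply serializes awaited calls behind an `asyncio.Lock`; a self-contained sketch of the same idea (the `do_work` coroutine is a stand-in, not vLLM code):

```python
import asyncio

async def run_with_lock(task, lock: asyncio.Lock, *args):
    # Same idea as _run_task_with_lock above: serialize awaited calls.
    async with lock:
        return await task(*args)

async def demo():
    lock = asyncio.Lock()

    async def do_work(i):  # hypothetical stand-in for a worker RPC
        await asyncio.sleep(0)
        return i

    # Gathered concurrently, but the lock runs them one at a time.
    return await asyncio.gather(
        run_with_lock(do_work, lock, 1),
        run_with_lock(do_work, lock, 2),
    )

print(asyncio.run(demo()))  # -> [1, 2]
```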

View File

@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
swizzle_blockscale,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
logger.info_once("Using flashinfer-trtllm for FP4")
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
if has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif envs.VLLM_USE_FBGEMM:
self.backend = "fbgemm"
try:
@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
) from exc
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
logger.info_once("Using flashinfer-cutlass for FP4")
else:
self.backend = "cutlass"
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
if self.backend == "none":
raise ValueError(
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
self.group_size = 16
@classmethod
@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer.alpha,
output_dtype,
)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
elif self.backend == "fbgemm":
out = torch.ops.fbgemm.f4f4bf16(
x_fp4,
@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
use_mx=False,
).to(output_dtype)
else:
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
if bias is not None:
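The backend selection above (and the analogous ModelOpt path in the next file) boils down to: an explicit `VLLM_NVFP4_GEMM_BACKEND` wins, otherwise fall back by availability. A simplified sketch, ignoring the fbgemm and marlin branches shown in the diffs:

```python
def resolve_nvfp4_backend(env_backend, has_flashinfer, cutlass_ok):
    if env_backend is not None:  # explicit VLLM_NVFP4_GEMM_BACKEND wins
        if env_backend.startswith("flashinfer-"):
            assert has_flashinfer, f"FlashInfer is required for {env_backend}"
        return env_backend
    if has_flashinfer:
        return "flashinfer-cutlass"  # preferred automatic choice
    if cutlass_ok:
        return "cutlass"
    raise ValueError("No valid NVFP4 GEMM backend found.")
```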

View File

@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
self.backend = "flashinfer-trtllm"
elif has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif is_fp4_marlin_supported():
self.backend = "marlin"
else:
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
if has_flashinfer():
self.backend = "flashinfer-cutlass"
elif cutlass_fp4_supported():
self.backend = "cutlass"
elif is_fp4_marlin_supported():
self.backend = "marlin"
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
if self.backend == "none":
raise ValueError(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above."
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
def create_weights(
self,
layer: torch.nn.Module,
@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.alpha,
output_dtype,
)
if self.backend == "flashinfer-trtllm":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
elif self.backend == "flashinfer-cutlass":
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
if self.backend.startswith("flashinfer-"):
backend_name = self.backend[len("flashinfer-") :]
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
else:
assert self.backend == "cutlass"
out = cutlass_scaled_fp4_mm(*mm_args)
if bias is not None:
@ -1542,23 +1546,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
del layer.w2_input_scale_quant
else:
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
assert layer.w13_weight_scale.shape[2] % 16 == 0, (
"Expected weight_scale.dim(1) to be divisible by 16"
)
assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, (
"Weight Blockscale must be represented as FP8-E4M3"
)
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
layer.w13_weight_scale = Parameter(
w13_blockscale_swizzled, requires_grad=False
)
assert layer.w2_weight_scale.shape[2] % 16 == 0, (
"Expected weight_scale.dim(1) to be divisible by 16"
)
assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, (
"Weight Blockscale must be represented as FP8-E4M3"
)
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
layer.w2_weight_scale = Parameter(
w2_blockscale_swizzled, requires_grad=False

View File

@ -32,7 +32,7 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.is_device_capability(100)
and current_platform.has_device_capability(100)
)

View File

@ -581,7 +581,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
# file not changed, use cached _ModelInfo properties
return _ModelInfo(**mi_dict["modelinfo"])
except Exception:
logger.exception(
logger.debug(
("Cached model info for class %s.%s error. "),
self.module_name,
self.class_name,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import datetime
import enum
@ -38,10 +37,8 @@ from argparse import (
RawDescriptionHelpFormatter,
_ArgumentGroup,
)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import (
AsyncGenerator,
Callable,
Collection,
Generator,
@ -51,7 +48,6 @@ from collections.abc import (
Mapping,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
@ -82,7 +78,6 @@ import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import Never, TypeIs, assert_never
import vllm.envs as envs
@ -223,278 +218,12 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)
class AsyncMicrobatchTokenizer:
"""Asynchronous tokenizer with micro-batching.
Pulls pending encode/decode requests from a queue and batches them
up to reduce overhead. A single-thread ThreadPoolExecutor is used
so the event loop stays responsive.
"""
def __init__(
self,
tokenizer,
max_batch_size: int = 32,
batch_wait_timeout_s: float = 0.002,
) -> None:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.batch_wait_timeout_s = batch_wait_timeout_s
self._loop = asyncio.get_running_loop()
self._queues: dict[
tuple,
asyncio.Queue[
tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future]
],
] = {}
self._batcher_tasks: list[asyncio.Task] = []
# Single-thread executor for blocking tokenizer calls.
self._executor = ThreadPoolExecutor(max_workers=1)
# === Public async API ===
async def __call__(self, prompt, **kwargs):
result_future: asyncio.Future = self._loop.create_future()
key = self._queue_key("encode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((prompt, kwargs, result_future))
return await result_future
async def decode(self, token_ids, **kwargs):
result_future: asyncio.Future = self._loop.create_future()
key = self._queue_key("decode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((token_ids, result_future))
return await result_future
# === Internal helpers ===
def _get_queue(
self, loop: asyncio.AbstractEventLoop, key: tuple
) -> asyncio.Queue[
tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future]
]:
"""Get the request queue for the given operation key, creating a new
queue and batcher task if needed."""
queue = self._queues.get(key)
if queue is None:
self._queues[key] = queue = asyncio.Queue()
if key[0] == "encode":
can_batch = key[1] != "other"
coro = self._batch_encode_loop(queue, can_batch)
else:
assert key[0] == "decode", f"Unknown operation type: {key[0]}."
coro = self._batch_decode_loop(queue)
self._batcher_tasks.append(loop.create_task(coro))
return queue
async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
"""Batch incoming encode requests for efficiency."""
while True:
prompt, kwargs, result_future = await queue.get()
prompts = [prompt]
kwargs_list = [kwargs]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(prompts) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
prompt, kwargs, result_future = await asyncio.wait_for(
queue.get(), timeout
)
prompts.append(prompt)
result_futures.append(result_future)
if not can_batch:
kwargs_list.append(kwargs)
except asyncio.TimeoutError:
break
try:
# If every request uses identical kwargs we can run a single
# batched tokenizer call for a big speed-up.
if can_batch and len(prompts) > 1:
batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
results = await self._loop.run_in_executor(
self._executor, batch_encode_fn
)
for i, fut in enumerate(result_futures):
if not fut.done():
data = {k: v[i] for k, v in results.items()}
fut.set_result(BatchEncoding(data))
else:
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
]
results = await self._loop.run_in_executor(
self._executor, encode_fn
)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
async def _batch_decode_loop(self, queue: asyncio.Queue):
"""Batch incoming decode requests for efficiency."""
while True:
token_ids, result_future = await queue.get()
token_ids_list = [token_ids]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(token_ids_list) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
token_ids, result_future = await asyncio.wait_for(
queue.get(), timeout
)
token_ids_list.append(token_ids)
result_futures.append(result_future)
except asyncio.TimeoutError:
break
try:
# Perform a single batched decode call for all requests
results = await self._loop.run_in_executor(
self._executor, self.tokenizer.batch_decode, token_ids_list
)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
def _queue_key(self, op: str, kwargs: dict) -> tuple:
"""
Return a normalized key describing operation + kwargs.
- `add_special_tokens`: {True/False}
- `truncation`: {True/False}
- If `truncation` is False (`max_length` is None),
returns a key for a can_batch queue.
- If `truncation` is True and `max_length` is None or equals
`tokenizer.model_max_length`, returns a key for a can_batch queue.
- Otherwise, returns a key for a cannot_batch queue.
Examples:
- Decode: ("decode",)
- Encode typical:
("encode", add_special_tokens, bool_truncation, max_length_label)
- Fallback: ("encode", "other")
"""
if op == "decode":
return ("decode",)
add_special_tokens = kwargs.get("add_special_tokens", True)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
if not truncation:
return "encode", add_special_tokens, False, None
model_max = getattr(self.tokenizer, "model_max_length", None)
if max_length is None or (model_max is not None and max_length == model_max):
return "encode", add_special_tokens, True, "model_max"
return "encode", "other"
def __del__(self):
if (
(tasks := getattr(self, "_batcher_tasks", None))
and (loop := getattr(self, "_loop", None))
and not loop.is_closed()
):
def cancel_tasks():
for task in tasks:
task.cancel()
loop.call_soon_threadsafe(cancel_tasks)
def cancel_task_threadsafe(task: Task):
if task and not task.done():
run_in_loop(task.get_loop(), task.cancel)
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
for sock in sockets:
if sock is not None:
sock.close(linger=0)
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
if in_loop(loop):
function(*args)
elif not loop.is_closed():
loop.call_soon_threadsafe(function, *args)
def in_loop(event_loop: AbstractEventLoop) -> bool:
try:
return asyncio.get_running_loop() == event_loop
except RuntimeError:
return False
async def merge_async_iterators(
*iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
if len(iterators) == 1:
# Fast-path single iterator case.
async for item in iterators[0]:
yield 0, item
return
loop = asyncio.get_running_loop()
awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
try:
while awaits:
done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
for d in done:
pair = awaits.pop(d)
try:
item = await d
i, it = pair
awaits[loop.create_task(anext(it))] = pair
yield i, item
except StopAsyncIteration:
pass
finally:
# Cancel any remaining iterators
for f, (_, it) in awaits.items():
with contextlib.suppress(BaseException):
f.cancel()
await it.aclose()
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
"""Collect all items from an async generator into a list."""
items = []
async for item in iterator:
items.append(item)
return items
def get_ip() -> str:
host_ip = envs.VLLM_HOST_IP
if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
@ -1803,12 +1532,6 @@ class FlexibleArgumentParser(ArgumentParser):
return processed_args
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.

299
vllm/utils/async_utils.py Normal file
View File

@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Contains helpers related to asynchronous code."""
import asyncio
import contextlib
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import AsyncGenerator, Awaitable, Callable
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from typing import TypeVar
from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import ParamSpec
P = ParamSpec("P")
T = TypeVar("T")
class AsyncMicrobatchTokenizer:
"""Asynchronous tokenizer with micro-batching.
Pulls pending encode/decode requests from a queue and batches them
up to reduce overhead. A single-thread ThreadPoolExecutor is used
so the event loop stays responsive.
"""
def __init__(
self,
tokenizer,
max_batch_size: int = 32,
batch_wait_timeout_s: float = 0.002,
) -> None:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.batch_wait_timeout_s = batch_wait_timeout_s
self._loop = asyncio.get_running_loop()
self._queues: dict[
tuple,
asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]],
] = {}
self._batcher_tasks: list[Task] = []
# Single-thread executor for blocking tokenizer calls.
self._executor = ThreadPoolExecutor(max_workers=1)
# === Public async API ===
async def __call__(self, prompt, **kwargs):
result_future: Future = self._loop.create_future()
key = self._queue_key("encode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((prompt, kwargs, result_future))
return await result_future
async def decode(self, token_ids, **kwargs):
result_future: Future = self._loop.create_future()
key = self._queue_key("decode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((token_ids, result_future))
return await result_future
# === Internal helpers ===
def _get_queue(
self, loop: asyncio.AbstractEventLoop, key: tuple
) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]:
"""Get the request queue for the given operation key, creating a new
queue and batcher task if needed."""
queue = self._queues.get(key)
if queue is None:
self._queues[key] = queue = asyncio.Queue()
if key[0] == "encode":
can_batch = key[1] != "other"
coro = self._batch_encode_loop(queue, can_batch)
else:
assert key[0] == "decode", f"Unknown operation type: {key[0]}."
coro = self._batch_decode_loop(queue)
self._batcher_tasks.append(loop.create_task(coro))
return queue
async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
"""Batch incoming encode requests for efficiency."""
while True:
prompt, kwargs, result_future = await queue.get()
prompts = [prompt]
kwargs_list = [kwargs]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(prompts) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
prompt, kwargs, result_future = await asyncio.wait_for(
queue.get(), timeout
)
prompts.append(prompt)
result_futures.append(result_future)
if not can_batch:
kwargs_list.append(kwargs)
except asyncio.TimeoutError:
break
try:
# If every request uses identical kwargs we can run a single
# batched tokenizer call for a big speed-up.
if can_batch and len(prompts) > 1:
batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
results = await self._loop.run_in_executor(
self._executor, batch_encode_fn
)
for i, fut in enumerate(result_futures):
if not fut.done():
data = {k: v[i] for k, v in results.items()}
fut.set_result(BatchEncoding(data))
else:
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
]
results = await self._loop.run_in_executor(
self._executor, encode_fn
)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
async def _batch_decode_loop(self, queue: asyncio.Queue):
"""Batch incoming decode requests for efficiency."""
while True:
token_ids, result_future = await queue.get()
token_ids_list = [token_ids]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(token_ids_list) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
token_ids, result_future = await asyncio.wait_for(
queue.get(), timeout
)
token_ids_list.append(token_ids)
result_futures.append(result_future)
except asyncio.TimeoutError:
break
try:
# Perform a single batched decode call for all requests
results = await self._loop.run_in_executor(
self._executor, self.tokenizer.batch_decode, token_ids_list
)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
def _queue_key(self, op: str, kwargs: dict) -> tuple:
"""
Return a normalized key describing operation + kwargs.
- `add_special_tokens`: {True/False}
- `truncation`: {True/False}
- If `truncation` is False (`max_length` is None),
returns a key for a can_batch queue.
- If `truncation` is True and `max_length` is None or equals
`tokenizer.model_max_length`, returns a key for a can_batch queue.
- Otherwise, returns a key for a cannot_batch queue.
Examples:
- Decode: ("decode",)
- Encode typical:
("encode", add_special_tokens, bool_truncation, max_length_label)
- Fallback: ("encode", "other")
"""
if op == "decode":
return ("decode",)
add_special_tokens = kwargs.get("add_special_tokens", True)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
if not truncation:
return "encode", add_special_tokens, False, None
model_max = getattr(self.tokenizer, "model_max_length", None)
if max_length is None or (model_max is not None and max_length == model_max):
return "encode", add_special_tokens, True, "model_max"
return "encode", "other"
def __del__(self):
if (
(tasks := getattr(self, "_batcher_tasks", None))
and (loop := getattr(self, "_loop", None))
and not loop.is_closed()
):
def cancel_tasks():
for task in tasks:
task.cancel()
loop.call_soon_threadsafe(cancel_tasks)
def cancel_task_threadsafe(task: Task):
if task and not task.done():
run_in_loop(task.get_loop(), task.cancel)
def make_async(
func: Callable[P, T],
executor: Executor | None = None,
) -> Callable[P, Awaitable[T]]:
"""
Take a blocking function and run it in an executor thread.
This function prevents the blocking function from blocking the
asyncio event loop.
The code in this function needs to be thread safe.
"""
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=executor, func=p_func)
return _async_wrapper
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
if in_loop(loop):
function(*args)
elif not loop.is_closed():
loop.call_soon_threadsafe(function, *args)
def in_loop(event_loop: AbstractEventLoop) -> bool:
try:
return asyncio.get_running_loop() == event_loop
except RuntimeError:
return False
async def merge_async_iterators(
*iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yielded the item.
"""
if len(iterators) == 1:
# Fast-path single iterator case.
async for item in iterators[0]:
yield 0, item
return
loop = asyncio.get_running_loop()
awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
try:
while awaits:
done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
for d in done:
pair = awaits.pop(d)
try:
item = await d
i, it = pair
awaits[loop.create_task(anext(it))] = pair
yield i, item
except StopAsyncIteration:
pass
finally:
# Cancel any remaining iterators
for f, (_, it) in awaits.items():
with contextlib.suppress(BaseException):
f.cancel()
await it.aclose()
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
"""Collect all items from an async generator into a list."""
items = []
async for item in iterator:
items.append(item)
return items
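To round out the move, a minimal usage sketch of `merge_async_iterators` from the new `vllm.utils.async_utils` module (the toy `count` generators are only for illustration):

```python
import asyncio
from vllm.utils.async_utils import merge_async_iterators

async def count(n):  # toy async generator
    for i in range(n):
        yield i

async def main():
    async for idx, item in merge_async_iterators(count(2), count(3)):
        print(f"iterator {idx} yielded {item}")

asyncio.run(main())
```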

View File

@ -6,12 +6,10 @@ Contains helpers that are applied to functions.
This is similar in concept to the `functools` module.
"""
import asyncio
import concurrent.futures
import inspect
import threading
import warnings
from collections.abc import Awaitable, Callable, Mapping
from collections.abc import Callable, Mapping
from functools import lru_cache, partial, wraps
from typing import Any, TypeVar
@ -32,26 +30,6 @@ def identity(value: T, **kwargs) -> T:
return value
def make_async(
func: Callable[P, T],
executor: concurrent.futures.Executor | None = None,
) -> Callable[P, Awaitable[T]]:
"""
Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
asyncio event loop.
The code in this function needs to be thread safe.
"""
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=executor, func=p_func)
return _async_wrapper
def run_once(f: Callable[P, None]) -> Callable[P, None]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
if wrapper.has_run: # type: ignore[attr-defined]

View File

@ -29,7 +29,8 @@ from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv
from vllm.utils import Device, as_list, cdiv
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.func import deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient

View File

@ -27,9 +27,9 @@ from vllm.utils import (
close_sockets,
get_open_port,
get_open_zmq_inproc_path,
in_loop,
make_zmq_socket,
)
from vllm.utils.async_utils import in_loop
from vllm.v1.engine import (
EngineCoreOutputs,
EngineCoreRequest,