Merge branch 'main' into wentao-optimize-startup-log-2
This commit is contained in: commit 4652ff1e02

3 .github/CODEOWNERS (vendored)
@@ -5,9 +5,7 @@
/vllm/attention @LucasWilkinson
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/model_executor/layers/fused_moe @mgoin
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/model_loader @22quinn
@@ -26,7 +24,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345

# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
/vllm/v1/attention @LucasWilkinson
/vllm/v1/attention/backends/flashinfer.py @mgoin
/vllm/v1/attention/backends/triton_attn.py @tdoublep
@@ -2,6 +2,7 @@
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
 #include "core/batch_invariant.hpp"
+#include "quantization/vectorization_utils.cuh"
 
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -18,11 +19,22 @@ __global__ void rms_norm_kernel(
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
+  const scalar_t* input_row = input + blockIdx.x * input_stride;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+  constexpr int VEC_SIZE = 8;
+  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      float x = static_cast<float>(vec.val[i]);
+      variance += x * x;
+    }
+  };
+  auto scalar_op = [&variance](const scalar_t& val) {
+    float x = static_cast<float>(val);
     variance += x * x;
-  }
+  };
+  vllm::vectorize_read_with_alignment<VEC_SIZE>(
+      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
@@ -10,6 +10,7 @@
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
 #include "core/batch_invariant.hpp"
+#include "quantization/vectorization_utils.cuh"
 
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   __shared__ float s_variance;
   float variance = 0.0f;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+  const scalar_t* input_row = input + blockIdx.x * input_stride;
+
+  constexpr int VEC_SIZE = 8;
+  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      float x = static_cast<float>(vec.val[i]);
+      variance += x * x;
+    }
+  };
+  auto scalar_op = [&variance](const scalar_t& val) {
+    float x = static_cast<float>(val);
     variance += x * x;
-  }
+  };
+  vllm::vectorize_read_with_alignment<VEC_SIZE>(
+      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
@@ -352,6 +352,16 @@ Supported models:

Flags: `--tool-call-parser qwen3_xml`

### Olmo 3 Models (`olmo3`)

Olmo 3 models output tool calls in a format that is very similar to the one expected by the `pythonic` parser (see below), with a few differences. Each tool call is a pythonic string, but parallel tool calls are newline-delimited, and the calls are wrapped in XML tags as `<function_calls>...</function_calls>`. The parser also accepts the JSON literals `true`, `false`, and `null` in addition to the pythonic ones (`True`, `False`, and `None`). An example of the format is sketched below.
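For illustration, a response with two parallel tool calls might look like the following (the function names and arguments are taken from this PR's unit-test fixtures, not from any real API):

```text
<function_calls>get_weather(city='San Francisco', metric='celsius')
register_user(name='John Doe', age=37, passed_test=true, role=null)</function_calls>
```

The parser turns each line into a separate tool call whose arguments are serialized as JSON, e.g. `{"city": "San Francisco", "metric": "celsius"}` for the first call.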

Supported models:

* TODO (will be updated after Olmo 3 release)

Flags: `--tool-call-parser olmo3`

### Models with Pythonic Tool Calls (`pythonic`)

A growing number of models output a Python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
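As a minimal sketch of the `pythonic` format (reusing the hypothetical functions from the test fixtures above), parallel calls appear as elements of a single Python list rather than on separate lines:

```text
[get_weather(city='San Francisco', metric='celsius'), do_something_cool(steps=[])]
```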
243 tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py (new file)
@ -0,0 +1,243 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "San Francisco", "metric": "celsius"}',
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"register_user(name='John Doe', "
|
||||
"age=37, "
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
|
||||
"register_user(name='John Doe', "
|
||||
"age=37, "
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=null, "
|
||||
"passed_test=true, "
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "John Doe", '
|
||||
'"age": 37, '
|
||||
'"address": {"city": "San Francisco", "state": "CA"}, '
|
||||
'"role": null, '
|
||||
'"passed_test": true, '
|
||||
'"aliases": ["John", "Johnny"]}',
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"additional_data": {}}',
|
||||
)
|
||||
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
|
||||
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming_json_literals",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming_json_literals",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content is None
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
assert actual.type == "function"
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps():
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
model_output_deltas = [
|
||||
"<function_calls>get_weather(city='San",
|
||||
" Francisco', metric='celsius')\n"
|
||||
f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
|
||||
f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
@@ -12,7 +12,7 @@ from vllm.entrypoints.openai.api_server import (
 from vllm.inputs import TextPrompt
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
-from vllm.utils import merge_async_iterators
+from vllm.utils.async_utils import merge_async_iterators
 
 MODEL_PATH = "zai-org/chatglm3-6b"
 LORA_RANK = 64
42 tests/utils_/test_async_utils.py (new file)
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncIterator

import pytest

from vllm.utils.async_utils import merge_async_iterators


async def _mock_async_iterator(idx: int):
    try:
        while True:
            yield f"item from iterator {idx}"
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print(f"iterator {idx} cancelled")


@pytest.mark.asyncio
async def test_merge_async_iterators():
    iterators = [_mock_async_iterator(i) for i in range(3)]
    merged_iterator = merge_async_iterators(*iterators)

    async def stream_output(generator: AsyncIterator[tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    for iterator in iterators:
        try:
            await asyncio.wait_for(anext(iterator), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e
@ -2,14 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -37,7 +35,6 @@ from vllm.utils import (
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
memory_profiling,
|
||||
merge_async_iterators,
|
||||
sha256,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
@ -48,39 +45,6 @@ from vllm.utils import (
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_async_iterators():
|
||||
async def mock_async_iterator(idx: int):
|
||||
try:
|
||||
while True:
|
||||
yield f"item from iterator {idx}"
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
print(f"iterator {idx} cancelled")
|
||||
|
||||
iterators = [mock_async_iterator(i) for i in range(3)]
|
||||
merged_iterator = merge_async_iterators(*iterators)
|
||||
|
||||
async def stream_output(generator: AsyncIterator[tuple[int, str]]):
|
||||
async for idx, output in generator:
|
||||
print(f"idx: {idx}, output: {output}")
|
||||
|
||||
task = asyncio.create_task(stream_output(merged_iterator))
|
||||
await asyncio.sleep(0.5)
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
for iterator in iterators:
|
||||
try:
|
||||
await asyncio.wait_for(anext(iterator), 1)
|
||||
except StopAsyncIteration:
|
||||
# All iterators should be cancelled and print this message.
|
||||
print("Iterator was cancelled normally")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
raise AssertionError() from e
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_PORT", "5678")
|
||||
|
||||
@@ -34,7 +34,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import merge_async_iterators
+from vllm.utils.async_utils import merge_async_iterators


 def run_vllm(
@ -5,6 +5,8 @@ import hashlib
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm import version
|
||||
@ -79,25 +81,43 @@ class ObservabilityConfig:
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
if (
|
||||
self.collect_detailed_traces is not None
|
||||
and len(self.collect_detailed_traces) == 1
|
||||
and "," in self.collect_detailed_traces[0]
|
||||
):
|
||||
self._parse_collect_detailed_traces()
|
||||
@field_validator("show_hidden_metrics_for_version")
|
||||
@classmethod
|
||||
def _validate_show_hidden_metrics_for_version(cls, value: str | None) -> str | None:
|
||||
if value is not None:
|
||||
# Raises an exception if the string is not a valid version.
|
||||
parse(value)
|
||||
return value
|
||||
|
||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
@field_validator("otlp_traces_endpoint")
|
||||
@classmethod
|
||||
def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
|
||||
if value is not None:
|
||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
|
||||
if not is_otel_available() and self.otlp_traces_endpoint is not None:
|
||||
if not is_otel_available():
|
||||
raise ValueError(
|
||||
"OpenTelemetry is not available. Unable to configure "
|
||||
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
||||
f"installed. Original error:\n{otel_import_error_traceback}"
|
||||
)
|
||||
return value
|
||||
|
||||
@field_validator("collect_detailed_traces")
|
||||
@classmethod
|
||||
def _validate_collect_detailed_traces(
|
||||
cls, value: list[DetailedTraceModules] | None
|
||||
) -> list[DetailedTraceModules] | None:
|
||||
"""Handle the legacy case where users might provide a comma-separated
|
||||
string instead of a list of strings."""
|
||||
if value is not None and len(value) == 1 and "," in value[0]:
|
||||
value = cast(list[DetailedTraceModules], value[0].split(","))
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_tracing_config(self):
|
||||
if self.collect_detailed_traces and not self.otlp_traces_endpoint:
|
||||
raise ValueError(
|
||||
"OpenTelemetry is not available. Unable to configure "
|
||||
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
||||
f"installed. Original error:\n{otel_import_error_traceback}"
|
||||
"collect_detailed_traces requires `--otlp-traces-endpoint` to be set."
|
||||
)
|
||||
|
||||
def _parse_collect_detailed_traces(self):
|
||||
assert isinstance(self.collect_detailed_traces, list)
|
||||
self.collect_detailed_traces = cast(
|
||||
list[DetailedTraceModules], self.collect_detailed_traces[0].split(",")
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -34,7 +34,8 @@ from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import as_list, merge_async_iterators
+from vllm.utils import as_list
+from vllm.utils.async_utils import merge_async_iterators

 logger = init_logger(__name__)

@@ -40,6 +40,7 @@ from vllm.outputs import (
 )
 from vllm.pooling_params import PoolingParams
 from vllm.utils import chunk_list
+from vllm.utils.async_utils import merge_async_iterators

 logger = init_logger(__name__)

@@ -387,8 +388,6 @@ class EmbeddingMixin(OpenAIServing):
             )
             generators.append(generator)

-        from vllm.utils import merge_async_iterators
-
         ctx.result_generator = merge_async_iterators(*generators)

         return None
@@ -90,14 +90,13 @@ from vllm.tracing import (
     log_tracing_disabled_warning,
 )
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import (
+from vllm.utils import is_list_of, random_uuid
+from vllm.utils.async_utils import (
     AsyncMicrobatchTokenizer,
     collect_from_async_generator,
-    is_list_of,
-    make_async,
     merge_async_iterators,
-    random_uuid,
 )
+from vllm.utils.func import make_async
 from vllm.v1.engine import EngineCoreRequest

 logger = init_logger(__name__)
@@ -36,7 +36,7 @@ from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.tasks import SupportedTask
-from vllm.utils import merge_async_iterators
+from vllm.utils.async_utils import merge_async_iterators

 logger = init_logger(__name__)

@@ -37,8 +37,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import merge_async_iterators
-from vllm.utils.func import make_async
+from vllm.utils.async_utils import make_async, merge_async_iterators

 logger = init_logger(__name__)

@@ -18,6 +18,7 @@ from .llama_tool_parser import Llama3JsonToolParser
 from .longcat_tool_parser import LongcatFlashToolParser
 from .minimax_tool_parser import MinimaxToolParser
 from .mistral_tool_parser import MistralToolParser
+from .olmo3_tool_parser import Olmo3PythonicToolParser
 from .openai_tool_parser import OpenAIToolParser
 from .phi4mini_tool_parser import Phi4MiniJsonToolParser
 from .pythonic_tool_parser import PythonicToolParser
@@ -45,6 +46,7 @@ __all__ = [
     "DeepSeekV31ToolParser",
     "Ernie45ToolParser",
     "xLAMToolParser",
+    "Olmo3PythonicToolParser",
     "MinimaxToolParser",
     "KimiK2ToolParser",
     "HunyuanA13BToolParser",
368 vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py (new file)
@ -0,0 +1,368 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ToolParserManager.register_module("olmo3")
|
||||
class Olmo3PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Olmo 3 models that produce tool calls as
|
||||
newline-separated pythonic strings.
|
||||
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
|
||||
Code copied from pythonic_tool_parser.py and updated to handle
|
||||
- newline separated pythonic tool calls.
|
||||
- argument values being null/true/false instead of Pythonic literals.
|
||||
"""
|
||||
|
||||
# TODO(mdepinet): Possible future improvements:
|
||||
# 1. Support text + tools separated by either <|python_tag|> or \n\n
|
||||
# 2. Support tools outside of a list (or separated by a semicolon).
|
||||
# This depends on item 1 for consistent streaming.
|
||||
# Neither of these are necessary for e.g. ToolACE, but both would help make
|
||||
# Llama3.2 models more reliable.
|
||||
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Rename for readability. This is NOT a tool id.
|
||||
@property
|
||||
def current_tool_index(self) -> int:
|
||||
return self.current_tool_id
|
||||
|
||||
@current_tool_index.setter
|
||||
def current_tool_index(self, value: int) -> None:
|
||||
self.current_tool_id = value
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
original_model_output = model_output
|
||||
# Remove xml tags.
|
||||
match = re.search(
|
||||
r"<function_calls>(.*?)</function_calls>", model_output, re.DOTALL
|
||||
)
|
||||
if match:
|
||||
model_output = match.group(1).strip()
|
||||
# Make the newline separated function calls into a list.
|
||||
model_output = ", ".join(
|
||||
[line.strip() for line in model_output.splitlines() if line.strip()]
|
||||
)
|
||||
model_output = f"[{model_output}]"
|
||||
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = (
|
||||
self.TOOL_CALL_REGEX.match(
|
||||
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
|
||||
)
|
||||
is not None
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning("Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug(
|
||||
"Regex timeout occurred when matching user input: %s", model_output
|
||||
)
|
||||
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=original_model_output
|
||||
)
|
||||
|
||||
try:
|
||||
module = ast.parse(model_output)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if isinstance(parsed, ast.List) and all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts
|
||||
):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=original_model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
# All function calls start with the <function_calls> tag.
|
||||
# But since this is streaming, we may have seen only part of the tag.
|
||||
if not current_text.startswith("<"):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
# Remove xml tags.
|
||||
if current_text.startswith("<function_calls>"):
|
||||
current_text = current_text[len("<function_calls>") :]
|
||||
if current_text.endswith("</function_calls>"):
|
||||
current_text = current_text[: -len("</function_calls>")]
|
||||
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
|
||||
# Make the newline separated function calls into a list.
|
||||
valid_text = ", ".join(
|
||||
[line.strip() for line in valid_text.splitlines() if line.strip()]
|
||||
)
|
||||
valid_text = f"[{valid_text}]"
|
||||
module = ast.parse(valid_text)
|
||||
parsed = getattr(module.body[0], "value", None)
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts
|
||||
):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a sequence of newline-separated calls"
|
||||
)
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
tool_deltas = []
|
||||
for index, new_call in enumerate(tool_calls):
|
||||
if index < self.current_tool_index:
|
||||
continue
|
||||
|
||||
self.current_tool_index = index
|
||||
if len(self.streamed_args_for_tool) == index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
new_call_complete = index < len(tool_calls) - 1 or ")" not in added_text
|
||||
if new_call_complete:
|
||||
self.current_tool_index += 1
|
||||
|
||||
withheld_suffix = added_text[:-1] if not new_call_complete else ""
|
||||
if not new_call_complete and added_text[-1] == ")":
|
||||
# Function call is incomplete. Withhold the closing bracket.
|
||||
withheld_suffix = withheld_suffix + "}"
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(
|
||||
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
|
||||
)
|
||||
|
||||
if delta is not None:
|
||||
tool_deltas.append(delta)
|
||||
if (
|
||||
delta.function is not None
|
||||
and delta.function.arguments is not None
|
||||
):
|
||||
self.streamed_args_for_tool[index] += delta.function.arguments
|
||||
|
||||
# HACK: serving_chat.py inspects the internal state of tool parsers
|
||||
# when determining its final streaming delta, automatically
|
||||
# adding autocompleted JSON.
|
||||
# These two lines avoid that nonsense while ensuring finish_reason
|
||||
# is set to tool_calls when at least one tool is called.
|
||||
if tool_deltas and not self.prev_tool_call_arr:
|
||||
self.prev_tool_call_arr = [{"arguments": {}}]
|
||||
|
||||
if tool_deltas:
|
||||
return DeltaMessage(tool_calls=tool_deltas)
|
||||
elif not added_text and self.current_tool_id > 0:
|
||||
# Return an empty DeltaMessage once the tool calls are all done
|
||||
# so that finish_reason gets set.
|
||||
return DeltaMessage(content="")
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction error"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: _get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [_get_parameter_value(v) for v in val.elts]
|
||||
# The model may return function calls where the values are null/true/false
|
||||
# because the system prompt has API description in json.
|
||||
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
|
||||
if val.id == "null":
|
||||
return None
|
||||
elif val.id == "true":
|
||||
return True
|
||||
elif val.id == "false":
|
||||
return False
|
||||
else:
|
||||
raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
|
||||
if not isinstance(call.func, ast.Name):
|
||||
raise _UnexpectedAstError("Invalid tool call name")
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = _get_parameter_value(keyword.value)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
|
||||
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
|
||||
) -> DeltaToolCall | None:
|
||||
new_call_args = new_call.function.arguments
|
||||
if withheld_suffix:
|
||||
assert new_call_args.endswith(withheld_suffix)
|
||||
new_call_args = new_call_args[: -len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(
|
||||
id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
arguments=new_call_args,
|
||||
),
|
||||
)
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args) :]
|
||||
return (
|
||||
DeltaToolCall(
|
||||
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
|
||||
)
|
||||
if arg_diff
|
||||
else None
|
||||
)
|
||||
@@ -17,7 +17,7 @@ from vllm.inputs.data import TextPrompt as EngineTextPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import AsyncMicrobatchTokenizer
+from vllm.utils.async_utils import AsyncMicrobatchTokenizer


 @dataclass(frozen=True)
17 vllm/envs.py
@@ -191,6 +191,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_ATTENTION: str | None = None
+    VLLM_NVFP4_GEMM_BACKEND: str | None = None
     VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
@@ -1292,11 +1293,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # If set, it means we pre-downloaded cubin files and flashinfer will
     # read the cubin files directly.
     "VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
-    # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
-    # Otherwise, uses the first available of: flashinfer cutlass GEMM,
-    # vllm cutlass GEMM, marlin GEMM.
-    "VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool(
-        int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))
+    # Supported options:
+    # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
+    # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
+    # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
+    # - <none>: automatically pick an available backend
+    "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
+        "VLLM_NVFP4_GEMM_BACKEND",
+        None,
+        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
     ),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
@@ -1492,7 +1497,6 @@ def compute_hash() -> str:
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
-        "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",
@@ -1524,6 +1528,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
         "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
         "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
+        "VLLM_NVFP4_GEMM_BACKEND",
         "VLLM_USE_FBGEMM",
     ]
     for key in environment_variables_to_hash:
@@ -17,7 +17,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import ExecuteModelRequest
 from vllm.tasks import SupportedTask
-from vllm.utils.func import make_async
+from vllm.utils.async_utils import make_async
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.worker.worker_base import WorkerBase

@@ -20,12 +20,11 @@ from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (
-    _run_task_with_lock,
     get_distributed_init_method,
     get_ip,
     get_open_port,
 )
-from vllm.utils.func import make_async
+from vllm.utils.async_utils import make_async
 from vllm.v1.outputs import SamplerOutput

 if ray is not None:
@@ -748,3 +747,9 @@ class RayDistributedExecutor(DistributedExecutorBase):
         # Assume that the Ray workers are healthy.
         # TODO: check the health of the Ray workers
         return
+
+
+async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
+    """Utility function to run async task in a lock"""
+    async with lock:
+        return await task(*args, **kwargs)
@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
run_nvfp4_emulations,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@ -29,10 +32,12 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
def __init__(self):
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
logger.info_once("Using flashinfer-trtllm for FP4")
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif envs.VLLM_USE_FBGEMM:
|
||||
self.backend = "fbgemm"
|
||||
try:
|
||||
@ -42,12 +47,17 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
|
||||
"Please install with: pip install fbgemm-gpu-genai"
|
||||
) from exc
|
||||
logger.info_once("Using FGBEMM-GPU-GENAI for FP4")
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
logger.info_once("Using flashinfer-cutlass for FP4")
|
||||
else:
|
||||
self.backend = "cutlass"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
@ -184,10 +194,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
elif self.backend == "fbgemm":
|
||||
out = torch.ops.fbgemm.f4f4bf16(
|
||||
x_fp4,
|
||||
@ -198,6 +207,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
use_mx=False,
|
||||
).to(output_dtype)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
|
||||
@ -926,22 +926,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
else:
|
||||
self.backend = "none"
|
||||
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
|
||||
if has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
|
||||
self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
|
||||
assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
"Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above."
|
||||
"No valid NVFP4 GEMM backend found. "
|
||||
"Please check your platform capability."
|
||||
)
|
||||
|
||||
logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1109,11 +1113,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
if self.backend.startswith("flashinfer-"):
|
||||
backend_name = self.backend[len("flashinfer-") :]
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
|
||||
else:
|
||||
assert self.backend == "cutlass"
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
@ -1542,23 +1546,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
del layer.w2_input_scale_quant
|
||||
else:
|
||||
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
|
||||
assert layer.w13_weight_scale.shape[2] % 16 == 0, (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16"
|
||||
)
|
||||
assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, (
|
||||
"Weight Blockscale must be represented as FP8-E4M3"
|
||||
)
|
||||
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
w13_blockscale_swizzled, requires_grad=False
|
||||
)
|
||||
|
||||
assert layer.w2_weight_scale.shape[2] % 16 == 0, (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16"
|
||||
)
|
||||
assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, (
|
||||
"Weight Blockscale must be represented as FP8-E4M3"
|
||||
)
|
||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||
layer.w2_weight_scale = Parameter(
|
||||
w2_blockscale_swizzled, requires_grad=False
|
||||
|
||||
@ -32,7 +32,7 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and current_platform.has_device_capability(100)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -581,7 +581,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
                 # file not changed, use cached _ModelInfo properties
                 return _ModelInfo(**mi_dict["modelinfo"])
         except Exception:
-            logger.exception(
+            logger.debug(
                 ("Cached model info for class %s.%s error. "),
                 self.module_name,
                 self.class_name,
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import datetime
|
||||
import enum
|
||||
@ -38,10 +37,8 @@ from argparse import (
|
||||
RawDescriptionHelpFormatter,
|
||||
_ArgumentGroup,
|
||||
)
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Collection,
|
||||
Generator,
|
||||
@ -51,7 +48,6 @@ from collections.abc import (
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures.process import ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cache, lru_cache, partial, wraps
|
||||
@ -82,7 +78,6 @@ import zmq.asyncio
|
||||
from packaging import version
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from typing_extensions import Never, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -223,278 +218,12 @@ def random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
class AsyncMicrobatchTokenizer:
|
||||
"""Asynchronous tokenizer with micro-batching.
|
||||
|
||||
Pulls pending encode/decode requests from a queue and batches them
|
||||
up to reduce overhead. A single-thread ThreadPoolExecutor is used
|
||||
so the event loop stays responsive.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
max_batch_size: int = 32,
|
||||
batch_wait_timeout_s: float = 0.002,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.max_batch_size = max_batch_size
|
||||
self.batch_wait_timeout_s = batch_wait_timeout_s
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._queues: dict[
|
||||
tuple,
|
||||
asyncio.Queue[
|
||||
tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future]
|
||||
],
|
||||
] = {}
|
||||
self._batcher_tasks: list[asyncio.Task] = []
|
||||
|
||||
# Single-thread executor for blocking tokenizer calls.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# === Public async API ===
|
||||
async def __call__(self, prompt, **kwargs):
|
||||
result_future: asyncio.Future = self._loop.create_future()
|
||||
key = self._queue_key("encode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
await queue.put((prompt, kwargs, result_future))
|
||||
return await result_future
|
||||
|
||||
async def decode(self, token_ids, **kwargs):
|
||||
result_future: asyncio.Future = self._loop.create_future()
|
||||
key = self._queue_key("decode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
await queue.put((token_ids, result_future))
|
||||
return await result_future
|
||||
|
||||
# === Internal helpers ===
|
||||
def _get_queue(
|
||||
self, loop: asyncio.AbstractEventLoop, key: tuple
|
||||
) -> asyncio.Queue[
|
||||
tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future]
|
||||
]:
|
||||
"""Get the request queue for the given operation key, creating a new
|
||||
queue and batcher task if needed."""
|
||||
queue = self._queues.get(key)
|
||||
if queue is None:
|
||||
self._queues[key] = queue = asyncio.Queue()
|
||||
if key[0] == "encode":
|
||||
can_batch = key[1] != "other"
|
||||
coro = self._batch_encode_loop(queue, can_batch)
|
||||
else:
|
||||
assert key[0] == "decode", f"Unknown operation type: {key[0]}."
|
||||
coro = self._batch_decode_loop(queue)
|
||||
self._batcher_tasks.append(loop.create_task(coro))
|
||||
return queue
|
||||
|
||||
async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
|
||||
"""Batch incoming encode requests for efficiency."""
|
||||
while True:
|
||||
prompt, kwargs, result_future = await queue.get()
|
||||
prompts = [prompt]
|
||||
kwargs_list = [kwargs]
|
||||
result_futures = [result_future]
|
||||
deadline = self._loop.time() + self.batch_wait_timeout_s
|
||||
|
||||
while len(prompts) < self.max_batch_size:
|
||||
timeout = deadline - self._loop.time()
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
prompt, kwargs, result_future = await asyncio.wait_for(
|
||||
queue.get(), timeout
|
||||
)
|
||||
prompts.append(prompt)
|
||||
result_futures.append(result_future)
|
||||
if not can_batch:
|
||||
kwargs_list.append(kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
try:
|
||||
# If every request uses identical kwargs we can run a single
|
||||
# batched tokenizer call for a big speed-up.
|
||||
if can_batch and len(prompts) > 1:
|
||||
batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, batch_encode_fn
|
||||
)
|
||||
|
||||
for i, fut in enumerate(result_futures):
|
||||
if not fut.done():
|
||||
data = {k: v[i] for k, v in results.items()}
|
||||
fut.set_result(BatchEncoding(data))
|
||||
else:
|
||||
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
|
||||
self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
|
||||
]
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, encode_fn
|
||||
)
|
||||
|
||||
for fut, res in zip(result_futures, results):
|
||||
if not fut.done():
|
||||
fut.set_result(res)
|
||||
except Exception as e:
|
||||
for fut in result_futures:
|
||||
if not fut.done():
|
||||
fut.set_exception(e)
|
||||
|
||||
async def _batch_decode_loop(self, queue: asyncio.Queue):
|
||||
"""Batch incoming decode requests for efficiency."""
|
||||
while True:
|
||||
token_ids, result_future = await queue.get()
|
||||
token_ids_list = [token_ids]
|
||||
result_futures = [result_future]
|
||||
deadline = self._loop.time() + self.batch_wait_timeout_s
|
||||
|
||||
while len(token_ids_list) < self.max_batch_size:
|
||||
timeout = deadline - self._loop.time()
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
token_ids, result_future = await asyncio.wait_for(
|
||||
queue.get(), timeout
|
||||
)
|
||||
token_ids_list.append(token_ids)
|
||||
result_futures.append(result_future)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
try:
|
||||
# Perform a single batched decode call for all requests
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, self.tokenizer.batch_decode, token_ids_list
|
||||
)
|
||||
for fut, res in zip(result_futures, results):
|
||||
if not fut.done():
|
||||
fut.set_result(res)
|
||||
except Exception as e:
|
||||
for fut in result_futures:
|
||||
if not fut.done():
|
||||
fut.set_exception(e)
|
||||
|
||||
    def _queue_key(self, op: str, kwargs: dict) -> tuple:
        """
        Return a normalized key describing operation + kwargs.

        - `add_special_tokens`: {True/False}
        - `truncation`: {True/False}
        - If `truncation` is False (`max_length` is None),
          returns a key for a can_batch queue.
        - If `truncation` is True and `max_length` is None or equals
          `tokenizer.model_max_length`, returns a key for a can_batch queue.
        - Otherwise, returns a key for a cannot_batch queue.

        Examples:
        - Decode: ("decode",)
        - Encode typical:
          ("encode", add_special_tokens, bool_truncation, max_length_label)
        - Fallback: ("encode", "other")
        """

        if op == "decode":
            return ("decode",)

        add_special_tokens = kwargs.get("add_special_tokens", True)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        if not truncation:
            return "encode", add_special_tokens, False, None

        model_max = getattr(self.tokenizer, "model_max_length", None)
        if max_length is None or (model_max is not None and max_length == model_max):
            return "encode", add_special_tokens, True, "model_max"

        return "encode", "other"

    def __del__(self):
        if (
            (tasks := getattr(self, "_batcher_tasks", None))
            and (loop := getattr(self, "_loop", None))
            and not loop.is_closed()
        ):

            def cancel_tasks():
                for task in tasks:
                    task.cancel()

            loop.call_soon_threadsafe(cancel_tasks)


def cancel_task_threadsafe(task: Task):
    if task and not task.done():
        run_in_loop(task.get_loop(), task.cancel)


def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
    for sock in sockets:
        if sock is not None:
            sock.close(linger=0)


def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
    if in_loop(loop):
        function(*args)
    elif not loop.is_closed():
        loop.call_soon_threadsafe(function, *args)


def in_loop(event_loop: AbstractEventLoop) -> bool:
    try:
        return asyncio.get_running_loop() == event_loop
    except RuntimeError:
        return False


async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handle the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item.
    """
    if len(iterators) == 1:
        # Fast-path single iterator case.
        async for item in iterators[0]:
            yield 0, item
        return

    loop = asyncio.get_running_loop()

    awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
    try:
        while awaits:
            done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    awaits[loop.create_task(anext(it))] = pair
                    yield i, item
                except StopAsyncIteration:
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()


async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
    """Collect all items from an async generator into a list."""
    items = []
    async for item in iterator:
        items.append(item)
    return items


def get_ip() -> str:
    host_ip = envs.VLLM_HOST_IP
    if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
@ -1803,12 +1532,6 @@ class FlexibleArgumentParser(ArgumentParser):
        return processed_args


async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
    """Utility function to run async task in a lock"""
    async with lock:
        return await task(*args, **kwargs)


# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.

299
vllm/utils/async_utils.py
Normal file
299
vllm/utils/async_utils.py
Normal file
@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Contains helpers related to asynchronous code."""

import asyncio
import contextlib
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import AsyncGenerator, Awaitable, Callable
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from typing import TypeVar

from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


class AsyncMicrobatchTokenizer:
    """Asynchronous tokenizer with micro-batching.

    Pulls pending encode/decode requests from a queue and batches them
    up to reduce overhead. A single-thread ThreadPoolExecutor is used
    so the event loop stays responsive.
    """

    def __init__(
        self,
        tokenizer,
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s

        self._loop = asyncio.get_running_loop()
        self._queues: dict[
            tuple,
            asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]],
        ] = {}
        self._batcher_tasks: list[Task] = []

        # Single-thread executor for blocking tokenizer calls.
        self._executor = ThreadPoolExecutor(max_workers=1)

    # === Public async API ===
    async def __call__(self, prompt, **kwargs):
        result_future: Future = self._loop.create_future()
        key = self._queue_key("encode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((prompt, kwargs, result_future))
        return await result_future

    async def decode(self, token_ids, **kwargs):
        result_future: Future = self._loop.create_future()
        key = self._queue_key("decode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((token_ids, result_future))
        return await result_future

    # === Internal helpers ===
    def _get_queue(
        self, loop: asyncio.AbstractEventLoop, key: tuple
    ) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]:
        """Get the request queue for the given operation key, creating a new
        queue and batcher task if needed."""
        queue = self._queues.get(key)
        if queue is None:
            self._queues[key] = queue = asyncio.Queue()
            if key[0] == "encode":
                can_batch = key[1] != "other"
                coro = self._batch_encode_loop(queue, can_batch)
            else:
                assert key[0] == "decode", f"Unknown operation type: {key[0]}."
                coro = self._batch_decode_loop(queue)
            self._batcher_tasks.append(loop.create_task(coro))
        return queue

    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
        """Batch incoming encode requests for efficiency."""
        while True:
            prompt, kwargs, result_future = await queue.get()
            prompts = [prompt]
            kwargs_list = [kwargs]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(prompts) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    prompt, kwargs, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    prompts.append(prompt)
                    result_futures.append(result_future)
                    if not can_batch:
                        kwargs_list.append(kwargs)
                except asyncio.TimeoutError:
                    break

            try:
                # If every request uses identical kwargs we can run a single
                # batched tokenizer call for a big speed-up.
                if can_batch and len(prompts) > 1:
                    batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
                    results = await self._loop.run_in_executor(
                        self._executor, batch_encode_fn
                    )

                    for i, fut in enumerate(result_futures):
                        if not fut.done():
                            data = {k: v[i] for k, v in results.items()}
                            fut.set_result(BatchEncoding(data))
                else:
                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
                        self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
                    ]
                    results = await self._loop.run_in_executor(
                        self._executor, encode_fn
                    )

                    for fut, res in zip(result_futures, results):
                        if not fut.done():
                            fut.set_result(res)
            except Exception as e:
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    async def _batch_decode_loop(self, queue: asyncio.Queue):
        """Batch incoming decode requests for efficiency."""
        while True:
            token_ids, result_future = await queue.get()
            token_ids_list = [token_ids]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(token_ids_list) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    token_ids, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    token_ids_list.append(token_ids)
                    result_futures.append(result_future)
                except asyncio.TimeoutError:
                    break

            try:
                # Perform a single batched decode call for all requests
                results = await self._loop.run_in_executor(
                    self._executor, self.tokenizer.batch_decode, token_ids_list
                )
                for fut, res in zip(result_futures, results):
                    if not fut.done():
                        fut.set_result(res)
            except Exception as e:
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    def _queue_key(self, op: str, kwargs: dict) -> tuple:
        """
        Return a normalized key describing operation + kwargs.

        - `add_special_tokens`: {True/False}
        - `truncation`: {True/False}
        - If `truncation` is False (`max_length` is None),
          returns a key for a can_batch queue.
        - If `truncation` is True and `max_length` is None or equals
          `tokenizer.model_max_length`, returns a key for a can_batch queue.
        - Otherwise, returns a key for a cannot_batch queue.

        Examples:
        - Decode: ("decode",)
        - Encode typical:
          ("encode", add_special_tokens, bool_truncation, max_length_label)
        - Fallback: ("encode", "other")
        """

        if op == "decode":
            return ("decode",)

        add_special_tokens = kwargs.get("add_special_tokens", True)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        if not truncation:
            return "encode", add_special_tokens, False, None

        model_max = getattr(self.tokenizer, "model_max_length", None)
        if max_length is None or (model_max is not None and max_length == model_max):
            return "encode", add_special_tokens, True, "model_max"

        return "encode", "other"

    def __del__(self):
        if (
            (tasks := getattr(self, "_batcher_tasks", None))
            and (loop := getattr(self, "_loop", None))
            and not loop.is_closed()
        ):

            def cancel_tasks():
                for task in tasks:
                    task.cancel()

            loop.call_soon_threadsafe(cancel_tasks)


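# Illustrative usage sketch, not part of the new module or this diff: shows how
# concurrent encode/decode calls are micro-batched. The `_example_*` name and
# the `tokenizer` argument are hypothetical; `tokenizer` is assumed to be a
# Hugging Face tokenizer instance created elsewhere (e.g. AutoTokenizer).
async def _example_microbatch_usage(tokenizer) -> None:
    async_tok = AsyncMicrobatchTokenizer(tokenizer)
    # Requests issued close together with identical kwargs share a queue key,
    # so they can be served by a single batched tokenizer call.
    enc_a, enc_b = await asyncio.gather(
        async_tok("first prompt", add_special_tokens=True),
        async_tok("second prompt", add_special_tokens=True),
    )
    # Decodes are batched separately on the ("decode",) queue.
    text = await async_tok.decode(enc_a["input_ids"])
    print(text, len(enc_b["input_ids"]))

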
def cancel_task_threadsafe(task: Task):
    if task and not task.done():
        run_in_loop(task.get_loop(), task.cancel)


def make_async(
    func: Callable[P, T],
    executor: Executor | None = None,
) -> Callable[P, Awaitable[T]]:
    """
    Take a blocking function, and run it in an executor thread.

    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=executor, func=p_func)

    return _async_wrapper


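# Illustrative sketch, not part of the new module or this diff: wrapping a
# blocking call so it runs on an executor thread instead of stalling the event
# loop. Both helper names below are hypothetical examples.
def _example_blocking_word_count(text: str) -> int:
    # Stand-in for any CPU-bound or otherwise blocking function.
    return len(text.split())


async def _example_make_async_usage() -> int:
    async_word_count = make_async(_example_blocking_word_count)
    # The wrapped call returns an awaitable executed in the default executor.
    return await async_word_count("one two three")

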
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
    if in_loop(loop):
        function(*args)
    elif not loop.is_closed():
        loop.call_soon_threadsafe(function, *args)


def in_loop(event_loop: AbstractEventLoop) -> bool:
    try:
        return asyncio.get_running_loop() == event_loop
    except RuntimeError:
        return False


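# Illustrative sketch, not part of the new module or this diff: interacting
# with an event loop that may live in another thread. run_in_loop() runs the
# callback directly when already inside `loop` and otherwise hands it off via
# call_soon_threadsafe(); cancel_task_threadsafe() builds on the same helper.
def _example_cross_thread_control(loop: AbstractEventLoop, task: Task) -> None:
    # Schedule a callback on `loop` regardless of the calling thread.
    run_in_loop(loop, print, "draining requests")
    # Cancel a task owned by `loop`; a no-op if it has already finished.
    cancel_task_threadsafe(task)

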
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item.
    """
    if len(iterators) == 1:
        # Fast-path single iterator case.
        async for item in iterators[0]:
            yield 0, item
        return

    loop = asyncio.get_running_loop()

    awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
    try:
        while awaits:
            done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    awaits[loop.create_task(anext(it))] = pair
                    yield i, item
                except StopAsyncIteration:
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()


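# Illustrative sketch, not part of the new module or this diff: merging two
# generators that finish at different times; each item is tagged with the
# index of the iterator it came from. The `_example_*` name is hypothetical.
async def _example_merge_usage() -> list[tuple[int, int]]:
    async def gen(n: int):
        for i in range(n):
            await asyncio.sleep(0)
            yield i

    out: list[tuple[int, int]] = []
    async for idx, item in merge_async_iterators(gen(2), gen(3)):
        out.append((idx, item))
    # Items from both iterators arrive interleaved, e.g.
    # [(0, 0), (1, 0), (0, 1), (1, 1), (1, 2)]; exact order depends on scheduling.
    return out

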
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
    """Collect all items from an async generator into a list."""
    items = []
    async for item in iterator:
        items.append(item)
    return items

@ -6,12 +6,10 @@ Contains helpers that are applied to functions.
This is similar in concept to the `functools` module.
"""

import asyncio
import concurrent.futures
import inspect
import threading
import warnings
from collections.abc import Awaitable, Callable, Mapping
from collections.abc import Callable, Mapping
from functools import lru_cache, partial, wraps
from typing import Any, TypeVar

@ -32,26 +30,6 @@ def identity(value: T, **kwargs) -> T:
    return value


def make_async(
    func: Callable[P, T],
    executor: concurrent.futures.Executor | None = None,
) -> Callable[P, Awaitable[T]]:
    """
    Take a blocking function, and run it on in an executor thread.

    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=executor, func=p_func)

    return _async_wrapper


def run_once(f: Callable[P, None]) -> Callable[P, None]:
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if wrapper.has_run:  # type: ignore[attr-defined]

@ -29,7 +29,8 @@ from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv
from vllm.utils import Device, as_list, cdiv
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.func import deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient

@ -27,9 +27,9 @@ from vllm.utils import (
    close_sockets,
    get_open_port,
    get_open_zmq_inproc_path,
    in_loop,
    make_zmq_socket,
)
from vllm.utils.async_utils import in_loop
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,