Merge remote-tracking branch 'origin/main' into mlm-full-lora-support

Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-12-24 02:00:02 +00:00
commit f9ca68541c
102 changed files with 3325 additions and 1455 deletions

View File

@ -104,7 +104,6 @@ def run_benchmark_with_batch_invariant(
random.seed(seed)
# Set environment variables
os.environ["VLLM_ATTENTION_BACKEND"] = backend
if batch_invariant:
os.environ["VLLM_BATCH_INVARIANT"] = "1"
else:
@ -140,6 +139,7 @@ def run_benchmark_with_batch_invariant(
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=tp_size,
attention_config={"backend": backend},
enable_prefix_caching=False,
)
init_time = time.perf_counter() - start_init

View File

@ -35,7 +35,7 @@ template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return (x + y - 1) / y * y;
return ((x + y - 1) / y) * y;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
@ -61,37 +61,47 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
// Each thread writes 4 uint32_t elements.
for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
col += blockDim.x * 4) {
SFout[row * sf_n_int + col] = 0x00;
}
}
int num_padded_cols = sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE;
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x;
colIdx < num_padded_cols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// If we are outside valid rows OR outside valid columns -> Use Zeros
if (rowIdx >= numRows || elem_idx >= numCols) {
memset(&in_vec, 0, sizeof(PackedVec));
} else {
// Valid Region: Load actual data
in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
}
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numKTiles, SFout);
out_pos =
auto out_val =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
// We do NOT write output for padding because the 'out' tensor is not
// padded.
if (rowIdx < numRows && elem_idx < numCols) {
// Same as inOffset because 8 elements are packed into one uint32_t.
out[inOffset] = out_val;
}
}
}
}
@ -134,4 +144,4 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
}

View File

@ -2,7 +2,7 @@ FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
add-apt-repository -y ppa:kobuk-team/intel-graphics
add-apt-repository -y ppa:kobuk-team/intel-graphics-staging
RUN apt clean && apt-get update -y && \
apt-get install -y --no-install-recommends --fix-missing \
@ -47,6 +47,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install --no-cache-dir \
-r requirements/xpu.txt
# arctic-inference is built from source which needs torch-xpu properly installed
# used for suffix method speculative decoding
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --no-cache-dir arctic-inference==0.1.1
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"
COPY . .

View File

@ -2,4 +2,4 @@
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
Please see [this guide](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) for more details on using vLLM with KServe.
You can use vLLM with KServe's [Hugging Face serving runtime](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) or via [`LLMInferenceService` that uses llm-d](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).

View File

@ -0,0 +1,5 @@
# llm-d
vLLM can be deployed with [llm-d](https://github.com/llm-d/llm-d), a Kubernetes-native distributed inference serving stack providing well-lit paths for anyone to serve large generative AI models at scale. It helps achieve the fastest "time to state-of-the-art (SOTA) performance" for key OSS models across most hardware accelerators and infrastructure providers.
You can use vLLM with llm-d directly by following [this guide](https://llm-d.ai/docs/guide) or via [KServe's LLMInferenceService](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).

View File

@ -12,6 +12,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following:
- [Helm](frameworks/helm.md)
- [InftyAI/llmaz](integrations/llmaz.md)
- [llm-d](integrations/llm-d.md)
- [KAITO](integrations/kaito.md)
- [KServe](integrations/kserve.md)
- [Kthena](integrations/kthena.md)

View File

@ -64,7 +64,7 @@ th:not(:first-child) {
| [CP](../configuration/optimization.md#chunked-prefill) | [](https://github.com/vllm-project/vllm/issues/2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [APC](automatic_prefix_caching.md) | [](https://github.com/vllm-project/vllm/issues/3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26963) |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [](https://github.com/vllm-project/vllm/issues/26970) |
| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |

View File

@ -490,6 +490,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | |
| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | |
| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | |
| `LlamaBidirectionalModel`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-embed-1b-v2`, etc. | ✅︎ | ✅︎ |
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ |
@ -543,8 +544,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen2ForSequenceClassification`<sup>C</sup> | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification`<sup>C</sup> | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
@ -562,6 +564,11 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
!!! note
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
!!! note
`nvidia/llama-nemotron-rerank-1b-v2` require a specific prompt format to work correctly.
Examples : [offline_using_template.py](../../examples/pooling/score/offline_using_template.py) [online_using_template.py](../../examples/pooling/score/online_using_template.py)
!!! note
Load the official original `mxbai-rerank-v2` by using the following command.

View File

@ -669,6 +669,21 @@ You can find the documentation for cross encoder models at [sbert.net](https://w
Code example: [examples/pooling/score/openai_cross_encoder_score.py](../../examples/pooling/score/openai_cross_encoder_score.py)
#### Score Template
Some scoring models require a specific prompt format to work correctly. You can specify a custom score template using the `--chat-template` parameter (see [Chat Template](#chat-template)).
Score templates are supported for **cross-encoder** models only. If you are using an **embedding** model for scoring, vLLM does not apply a score template.
Like chat templates, the score template receives a `messages` list. For scoring, each message has a `role` attribute—either `"query"` or `"document"`. For the usual kind of point-wise cross-encoder, you can expect exactly two messages: one query and one document. To access the query and document content, use Jinja's `selectattr` filter:
- **Query**: `{{ (messages | selectattr("role", "eq", "query") | first).content }}`
- **Document**: `{{ (messages | selectattr("role", "eq", "document") | first).content }}`
This approach is more robust than index-based access (`messages[0]`, `messages[1]`) because it selects messages by their semantic role. It also avoids assumptions about message ordering if additional message types are added to `messages` in the future.
Example template file: [examples/pooling/score/template/nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja)
#### Single inference
You can pass a string to both `text_1` and `text_2`, forming a single sentence pair.

View File

@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from pathlib import Path
from vllm import LLM
model_name = "nvidia/llama-nemotron-rerank-1b-v2"
# Path to template file
template_path = Path(__file__).parent / "template" / "nemotron-rerank.jinja"
chat_template = template_path.read_text()
llm = LLM(model=model_name, runner="pooling", trust_remote_code=True)
query = "how much protein should a female eat?"
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]
outputs = llm.score(query, documents, chat_template=chat_template)
print("-" * 30)
print([output.outputs.score for output in outputs])
print("-" * 30)

View File

@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using the rerank API with template.
run:
vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja
"""
import json
import requests
url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
query = "how much protein should a female eat?"
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]
data = {
"model": "nvidia/llama-nemotron-rerank-1b-v2",
"query": query,
"documents": documents,
}
def main():
response = requests.post(url, headers=headers, json=data)
# Check the response
if response.status_code == 200:
print("Request successful!")
print(json.dumps(response.json(), indent=2))
else:
print(f"Request failed with status code: {response.status_code}")
print(response.text)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,3 @@
question:{{ (messages | selectattr("role", "eq", "query") | first).content }}
passage:{{ (messages | selectattr("role", "eq", "document") | first).content }}

View File

@ -557,7 +557,8 @@ def test_rms_group_quant(
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig(
# Testing properties

View File

@ -77,6 +77,7 @@ def test_dynamic_shapes_compilation(
"evaluate_guards": evaluate_guards,
},
},
max_model_len=1024,
)
output = model.generate(prompt)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
@ -53,37 +52,61 @@ class TestModel(torch.nn.Module):
hidden_size: int,
eps: float,
group_shape: GroupShape,
cuda_force_torch: bool,
use_aiter: bool = False,
cuda_force_torch: bool = False,
use_aiter_quant_op: bool = True,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.use_aiter = use_aiter
self.use_aiter_quant_op = use_aiter_quant_op
self.cuda_force_torch = cuda_force_torch
self.group_shape = group_shape
self.enable_quant_fp8_custom_op = None # Will be set later if applicable
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
if group_shape.is_per_group():
self.wscale = [
torch.rand(
(hidden_size // group_shape[1], hidden_size // group_shape[1]),
dtype=torch.float32,
)
for _ in range(3)
]
else:
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
static = group_shape == GroupShape.PER_TENSOR
# Setup quantization scale descriptor
static = group_shape == GroupShape.PER_TENSOR and not use_aiter
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
# Setup scales
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else:
self.scale = [None for _ in range(3)]
# Setup weights
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
]
if not group_shape.is_per_group():
if not group_shape.is_per_group() or use_aiter:
self.w = [self.w[0].t() for _ in range(3)]
# Setup weight scales
if group_shape.is_per_group():
scale_size = (
(hidden_size + 128 - 1) // 128
if use_aiter
else hidden_size // group_shape[1]
)
wscale_shape: tuple[int, ...] = (scale_size, scale_size)
else:
wscale_shape = (1,)
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
# Setup FP8 linear operation
is_per_group = group_shape.is_per_group()
if is_per_group and use_aiter:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=group_shape,
use_aiter_and_is_supported=use_aiter_quant_op,
)
# AITER blockwise doesn't use enable_quant_fp8_custom_op
elif is_per_group:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
@ -91,6 +114,13 @@ class TestModel(torch.nn.Module):
use_aiter_and_is_supported=False,
)
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
elif use_aiter:
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=group_shape,
)
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
else:
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
@ -100,7 +130,6 @@ class TestModel(torch.nn.Module):
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.group_shape = group_shape
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
@ -126,19 +155,49 @@ class TestModel(torch.nn.Module):
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_before(self):
if (
self.use_aiter
and self.group_shape.is_per_group()
and current_platform.is_fp8_fnuz()
):
return [rocm_aiter_ops.get_group_quant_op()]
if self.use_aiter and self.group_shape.is_per_group():
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
if self.use_aiter and self.use_aiter_quant_op:
return [rocm_aiter_ops.get_per_token_quant_op()]
if self.use_aiter:
return [QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op:
return [QUANT_OPS[self.quant_key]]
return [torch.ops.aten.reciprocal]
def ops_in_model_after(self):
if self.use_aiter and self.group_shape.is_per_group():
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSFp8GroupQuantPattern,
AiterRMSFp8GroupQuantPattern,
)
return [
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
AiterRMSFp8GroupQuantPattern.FUSED_OP,
]
if self.use_aiter:
from vllm.compilation.rocm_aiter_fusion import (
AiterFusedAddRMSNormDynamicQuantPattern,
AiterRMSNormDynamicQuantPattern,
)
return [
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
AiterRMSNormDynamicQuantPattern.FUSED_OP,
]
return [
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
@ -155,67 +214,45 @@ GROUP_SHAPES = [
]
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, **kwargs):
super().__init__()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
def _run_fusion_test(
model,
fusion_pass,
vllm_config,
dtype,
hidden_size,
num_tokens,
):
"""Helper function for common fusion test logic.
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = [
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
for _ in range(3)
]
Must be called within vllm_config context.
"""
noop_pass = NoOpEliminationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
self.eps = eps
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
# make sure resid is used for replacement to work
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
x3, resid, self.norm_weight[2], self.eps
)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
x4, resid, self.norm_weight[3], self.eps
)
return y4
assert fusion_pass.matched_count == 3
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())
def ops_in_model_before(self):
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
return backend, backend2
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@ -223,11 +260,8 @@ class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
@pytest.mark.parametrize(
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
@ -242,23 +276,13 @@ def test_fusion_rmsnorm_quant(
num_tokens,
eps,
group_shape,
model_class,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
# Skip test for 64-bit group shape when running with cutlass or deepgemm
if group_shape == GroupShape(1, 64) and (
cutlass_block_fp8_supported() or is_deep_gemm_supported()
):
@ -269,6 +293,7 @@ def test_fusion_rmsnorm_quant(
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
@ -279,60 +304,97 @@ def test_fusion_rmsnorm_quant(
),
),
)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
# Setup device before model creation
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = model_class(
fusion_pass = RMSNormQuantFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size,
eps=eps,
group_shape=group_shape,
use_aiter=False,
cuda_force_torch=cuda_force_torch,
)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 3
backend.check_before_ops(model.ops_in_model_before())
backend, _ = _run_fusion_test(
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
)
backend.check_before_ops(
model.ops_in_model_before_partial(), fully_replaced=False
)
backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if (
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2
GROUP_SHAPE_QUANT_OPS_MATCHS = [
(GroupShape.PER_TOKEN, True),
(GroupShape.PER_TOKEN, False),
(GroupShape(1, 128), True),
]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
"group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
)
@pytest.mark.skipif(
(not current_platform.is_rocm() or not IS_AITER_FOUND),
reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_quant(
dtype: torch.dtype,
hidden_size: int,
num_tokens: int,
eps: float,
group_shape: GroupShape,
use_aiter_quant_op: bool,
monkeypatch: pytest.MonkeyPatch,
):
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
),
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size,
eps=eps,
group_shape=group_shape,
use_aiter=True,
use_aiter_quant_op=use_aiter_quant_op,
)
_run_fusion_test(
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
)

View File

@ -8,7 +8,7 @@ import pytest
import pytest_asyncio
from vllm.assets.audio import AudioAsset
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
from vllm.multimodal.utils import encode_audio_base64, encode_audio_url, fetch_audio
from ...utils import RemoteOpenAIServer
@ -53,6 +53,14 @@ def base64_encoded_audio() -> dict[str, str]:
}
@pytest.fixture(scope="session")
def url_encoded_audio() -> dict[str, str]:
return {
audio_url: encode_audio_url(*fetch_audio(audio_url))
for audio_url in TEST_AUDIO_URLS
}
def dummy_messages_from_audio_url(
audio_urls: str | list[str],
content_text: str = "What's happening in this audio?",
@ -149,11 +157,9 @@ async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI,
model_name: str,
audio_url: str,
base64_encoded_audio: dict[str, str],
url_encoded_audio: dict[str, str],
):
messages = dummy_messages_from_audio_url(
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
)
messages = dummy_messages_from_audio_url(url_encoded_audio[audio_url])
# test single completion
chat_completion = await client.chat.completions.create(

View File

@ -7,7 +7,7 @@ import openai
import pytest
import pytest_asyncio
from vllm.multimodal.utils import encode_video_base64, fetch_video
from vllm.multimodal.utils import encode_video_url, fetch_video
from ...utils import RemoteOpenAIServer
@ -48,9 +48,9 @@ async def client(server):
@pytest.fixture(scope="session")
def base64_encoded_video() -> dict[str, str]:
def url_encoded_video() -> dict[str, str]:
return {
video_url: encode_video_base64(fetch_video(video_url)[0])
video_url: encode_video_url(fetch_video(video_url)[0])
for video_url in TEST_VIDEO_URLS
}
@ -175,11 +175,9 @@ async def test_single_chat_session_video_base64encoded(
client: openai.AsyncOpenAI,
model_name: str,
video_url: str,
base64_encoded_video: dict[str, str],
url_encoded_video: dict[str, str],
):
messages = dummy_messages_from_video_url(
f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
)
messages = dummy_messages_from_video_url(url_encoded_video[video_url])
# test single completion
chat_completion = await client.chat.completions.create(
@ -223,11 +221,9 @@ async def test_single_chat_session_video_base64encoded_beamsearch(
client: openai.AsyncOpenAI,
model_name: str,
video_url: str,
base64_encoded_video: dict[str, str],
url_encoded_video: dict[str, str],
):
messages = dummy_messages_from_video_url(
f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
)
messages = dummy_messages_from_video_url(url_encoded_video[video_url])
chat_completion = await client.chat.completions.create(
model=model_name,

View File

@ -9,7 +9,7 @@ import pytest_asyncio
from transformers import AutoProcessor
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image
from vllm.multimodal.utils import encode_image_url, fetch_image
from ...utils import RemoteOpenAIServer
@ -35,7 +35,7 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
],
[
"The image shows a Venn diagram with three over",
"The image shows a colorful Venn diagram with",
"The image displays a Venn diagram with three over",
],
[
"This image displays a gradient of colors ranging from",
@ -70,11 +70,9 @@ async def client(server):
@pytest.fixture(scope="session")
def base64_encoded_image(local_asset_server) -> dict[str, str]:
def url_encoded_image(local_asset_server) -> dict[str, str]:
return {
image_asset: encode_image_base64(
local_asset_server.get_image_asset(image_asset)
)
image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
for image_asset in TEST_IMAGE_ASSETS
}
@ -234,11 +232,11 @@ async def test_single_chat_session_image_base64encoded(
model_name: str,
raw_image_url: str,
image_url: str,
base64_encoded_image: dict[str, str],
url_encoded_image: dict[str, str],
):
content_text = "What's in this image?"
messages = dummy_messages_from_image_url(
f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}",
url_encoded_image[raw_image_url],
content_text,
)
@ -288,15 +286,13 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
client: openai.AsyncOpenAI,
model_name: str,
image_idx: int,
base64_encoded_image: dict[str, str],
url_encoded_image: dict[str, str],
):
# NOTE: This test also validates that we pass MM data through beam search
raw_image_url = TEST_IMAGE_ASSETS[image_idx]
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
messages = dummy_messages_from_image_url(
f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
)
messages = dummy_messages_from_image_url(url_encoded_image[raw_image_url])
chat_completion = await client.chat.completions.create(
model=model_name,

View File

@ -10,7 +10,7 @@ from transformers import AutoProcessor
from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image
from vllm.multimodal.utils import fetch_image
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
MAXIMUM_IMAGES = 2
@ -48,14 +48,6 @@ def server():
yield remote_server
@pytest.fixture(scope="session")
def base64_encoded_image(local_asset_server) -> dict[str, str]:
return {
image_url: encode_image_base64(local_asset_server.get_image_asset(image_url))
for image_url in TEST_IMAGE_ASSETS
}
def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=True, num_crops=4

View File

@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.score_utils import get_score_prompt
from vllm.inputs import TokensPrompt
from vllm.tokenizers import get_tokenizer
# A cross-encoder model for testing
CROSS_ENCODER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"
def assert_prompt_tokenization_consistent(
tokenizer, full_prompt, engine_prompt, add_special_tokens=True
):
"""Verify that engine_prompt token_ids match tokenizing full_prompt."""
expected_ids = tokenizer(full_prompt, add_special_tokens=add_special_tokens)[
"input_ids"
]
actual_ids = engine_prompt["prompt_token_ids"]
assert actual_ids == expected_ids, (
f"Token IDs don't match.\nExpected: {expected_ids}\nActual: {actual_ids}"
)
@pytest.fixture(scope="module")
def cross_encoder_model_config():
return ModelConfig(
CROSS_ENCODER_MODEL_ID,
runner="pooling",
)
@pytest.fixture(scope="module")
def cross_encoder_tokenizer(cross_encoder_model_config):
return get_tokenizer(
CROSS_ENCODER_MODEL_ID,
trust_remote_code=cross_encoder_model_config.trust_remote_code,
)
@pytest.fixture(scope="module")
def llm_reranker_model_config():
"""Model config for LLM-as-reranker style (no pad token)."""
config = ModelConfig(
CROSS_ENCODER_MODEL_ID,
runner="pooling",
)
# use_pad_token is a property that reads from hf_config,
# so we set it there to override the default (True)
config.hf_config.use_pad_token = False
return config
@pytest.fixture
def tokenization_kwargs():
"""Common tokenization kwargs used across tests."""
return {"add_special_tokens": True, "return_tensors": None}
@pytest.fixture
def mock_model_with_score_template():
"""Mock model class that supports score template and tracks post_process calls."""
class MockModelWithScoreTemplate:
supports_score_template = True
post_process_called: list[TokensPrompt] = []
@staticmethod
def get_score_template(p1: str, p2: str) -> str:
return f"[QUERY]{p1}[SEP][DOC]{p2}"
@staticmethod
def post_process_tokens(prompt: TokensPrompt) -> None:
MockModelWithScoreTemplate.post_process_called.append(prompt)
return MockModelWithScoreTemplate
@pytest.fixture
def mock_model_no_score_template():
"""Mock model class that does not support score template."""
class MockModelNoScoreTemplate:
supports_score_template = False
return MockModelNoScoreTemplate
class TestGetScorePrompt:
"""Tests for the get_score_prompt function."""
def test_tokenization_kwargs_passed_through(
self,
llm_reranker_model_config,
cross_encoder_tokenizer,
):
"""Test that tokenization kwargs are properly passed through."""
data_1 = "Query text"
data_2 = "Document text"
# Test with truncation - custom kwargs for this test
custom_tokenization_kwargs = {
"add_special_tokens": True,
"return_tensors": None,
"truncation": True,
"max_length": 20,
}
full_prompt, engine_prompt = get_score_prompt(
llm_reranker_model_config,
cross_encoder_tokenizer,
custom_tokenization_kwargs,
data_1,
data_2,
)
assert isinstance(full_prompt, str)
assert "prompt_token_ids" in engine_prompt
# With max_length=20 and truncation, should not exceed this
assert len(engine_prompt["prompt_token_ids"]) <= 20
# Since truncation was applied, token_ids should be a prefix of full encoding
full_ids = cross_encoder_tokenizer(full_prompt, add_special_tokens=True)[
"input_ids"
]
actual_ids = engine_prompt["prompt_token_ids"]
assert full_ids[: len(actual_ids)] == actual_ids, (
f"Token IDs are not a prefix of full encoding.\n"
f"Full IDs: {full_ids}\n"
f"Actual IDs: {actual_ids}"
)
def test_model_supports_score_template(
self,
cross_encoder_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
mock_model_with_score_template,
):
"""Test when model supports score template (no score_template arg)."""
with patch(
"vllm.model_executor.model_loader.get_model_cls",
return_value=mock_model_with_score_template,
):
full_prompt, engine_prompt = get_score_prompt(
cross_encoder_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
"query text",
"document text",
)
assert full_prompt == "[QUERY]query text[SEP][DOC]document text"
assert "prompt_token_ids" in engine_prompt
assert len(engine_prompt["prompt_token_ids"]) > 0
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)
def test_model_supports_score_template_but_custom_template_provided(
self,
cross_encoder_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
mock_model_with_score_template,
):
"""Test when model supports score template but custom template is provided."""
template = (
'TEMPLATE_USED {{ messages[0]["content"] }} {{ messages[1]["content"] }}'
)
with (
patch(
"vllm.model_executor.model_loader.get_model_cls",
return_value=mock_model_with_score_template,
),
):
full_prompt, engine_prompt = get_score_prompt(
cross_encoder_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
"query",
"doc",
score_template=template, # Providing a template
)
assert "prompt_token_ids" in engine_prompt
assert full_prompt == "TEMPLATE_USED query doc"
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)
def test_not_using_default_template(
self,
llm_reranker_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
mock_model_no_score_template,
):
# FIXME: Models implementing SupportsScoreTemplate must use their custom
# template implementation by default to preserve existing functionality.
# Attempting to use tokenizer_config.json templates would most likely break
# these models, as often they just inherit the template from the original LLM.
# CLI --chat-template overrides are still supported.
with (
patch(
"vllm.model_executor.model_loader.get_model_cls",
return_value=mock_model_no_score_template,
),
patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template",
return_value="test querytest doc",
),
):
full_prompt, engine_prompt = get_score_prompt(
llm_reranker_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
"test query",
"test doc",
)
assert full_prompt == "test querytest doc"
assert "prompt_token_ids" in engine_prompt
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)
def test_fallback_with_pad_token(
self,
cross_encoder_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
mock_model_no_score_template,
):
"""Test fallback path when ChatTemplateResolutionError
and use_pad_token=True."""
with (
patch(
"vllm.model_executor.model_loader.get_model_cls",
return_value=mock_model_no_score_template,
),
patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template",
side_effect=ChatTemplateResolutionError("No template"),
),
):
full_prompt, engine_prompt = get_score_prompt(
cross_encoder_model_config, # use_pad_token=True
cross_encoder_tokenizer,
tokenization_kwargs,
"query",
"document",
)
assert "prompt_token_ids" in engine_prompt
# Should have token_type_ids from text_pair encoding
assert "token_type_ids" in engine_prompt
assert "query" in full_prompt
assert "document" in full_prompt
assert full_prompt != "querydocument"
assert (
engine_prompt["prompt_token_ids"]
== cross_encoder_tokenizer(
"query", text_pair="document", add_special_tokens=True
)["input_ids"]
)
# FIXME(?): add_special_tokens=False is needed because in this case
# full_prompt is obtained by decoding the tokenized prompt, which includes
# special tokens and we would get duplicated special tokens otherwise.
# This is inconsistent with other cases.
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer,
full_prompt,
engine_prompt,
add_special_tokens=False,
)
def test_fallback_without_pad_token(
self,
llm_reranker_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
mock_model_no_score_template,
):
"""Test fallback path when ChatTemplateResolutionError
and use_pad_token=False."""
with (
patch(
"vllm.model_executor.model_loader.get_model_cls",
return_value=mock_model_no_score_template,
),
patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template",
side_effect=ChatTemplateResolutionError("No template"),
),
):
full_prompt, engine_prompt = get_score_prompt(
llm_reranker_model_config, # use_pad_token=False
cross_encoder_tokenizer,
tokenization_kwargs,
"query",
"document",
)
assert full_prompt == "querydocument"
assert "prompt_token_ids" in engine_prompt
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)
def test_post_process_tokens_called(
self,
cross_encoder_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
mock_model_with_score_template,
):
"""Test that post_process_tokens is called on the engine prompt."""
# Reset the call tracker
mock_model_with_score_template.post_process_called.clear()
with (
patch(
"vllm.model_executor.model_loader.get_model_cls",
return_value=mock_model_with_score_template,
),
patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template",
side_effect=ChatTemplateResolutionError("No template"),
),
):
full_prompt, engine_prompt = get_score_prompt(
cross_encoder_model_config,
cross_encoder_tokenizer,
tokenization_kwargs,
"query",
"doc",
)
# post_process_tokens should have been called once
assert len(mock_model_with_score_template.post_process_called) == 1
assert mock_model_with_score_template.post_process_called[0] is engine_prompt
assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt
)

View File

@ -25,9 +25,9 @@ from vllm.entrypoints.chat_utils import (
)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (
encode_audio_base64,
encode_image_base64,
encode_video_base64,
encode_audio_url,
encode_image_url,
encode_video_url,
)
from vllm.tokenizers import get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
@ -141,22 +141,19 @@ def mistral_model_config():
@pytest.fixture(scope="module")
def image_url():
image = ImageAsset("cherry_blossom")
base64 = encode_image_base64(image.pil_image)
return f"data:image/jpeg;base64,{base64}"
return encode_image_url(image.pil_image)
@pytest.fixture(scope="module")
def video_url():
video = VideoAsset("baby_reading", 1)
base64 = encode_video_base64(video.np_ndarrays)
return f"data:video/jpeg;base64,{base64}"
return encode_video_url(video.np_ndarrays)
@pytest.fixture(scope="module")
def audio_url():
audio = AudioAsset("mary_had_lamb")
base64 = encode_audio_base64(*audio.audio_and_sample_rate)
return f"data:audio/ogg;base64,{base64}"
return encode_audio_url(*audio.audio_and_sample_rate)
def _assert_mm_data_is_image_input(

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from pathlib import Path
import mteb
import numpy as np
@ -19,6 +20,11 @@ from tests.models.utils import (
get_vllm_extra_kwargs,
)
template_home = (
Path(__file__).parent.parent.parent.parent.parent
/ "examples/pooling/score/template"
)
# Most embedding models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
# results in differences less than 1e-4
@ -102,30 +108,6 @@ class VllmMtebEncoder(mteb.EncoderProtocol):
return sim
class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
mteb_model_meta = _empty_model_meta
def __init__(self, vllm_model):
self.llm = vllm_model
self.rng = np.random.default_rng(seed=42)
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
outputs = self.llm.score(
queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
)
scores = np.array(outputs)
return scores
class OpenAIClientMtebEncoder(VllmMtebEncoder):
def __init__(self, model_name: str, client):
self.model_name = model_name
@ -153,6 +135,35 @@ class OpenAIClientMtebEncoder(VllmMtebEncoder):
return embeds
class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
mteb_model_meta = _empty_model_meta
def __init__(self, vllm_model):
self.llm = vllm_model
self.rng = np.random.default_rng(seed=42)
self.chat_template: str | None = getattr(vllm_model, "chat_template", None)
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
outputs = self.llm.score(
queries,
corpus,
truncate_prompt_tokens=-1,
use_tqdm=False,
chat_template=self.chat_template,
)
scores = np.array(outputs)
return scores
class ScoreClientMtebEncoder(mteb.CrossEncoderProtocol):
mteb_model_meta = _empty_model_meta
@ -387,6 +398,11 @@ def mteb_test_rerank_models(
== model_info.default_pooling_type
)
chat_template: str | None = None
if model_info.chat_template_name is not None:
chat_template = (template_home / model_info.chat_template_name).read_text()
vllm_model.chat_template = chat_template
vllm_main_score = run_mteb_rerank(
vllm_mteb_encoder(vllm_model),
tasks=MTEB_RERANK_TASKS,

View File

@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.models.utils import (
EmbedModelInfo,
LASTPoolingEmbedModelInfo,
LASTPoolingRerankModelInfo,
RerankModelInfo,
)
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
EMBEDDING_MODELS = [
LASTPoolingEmbedModelInfo(
"nvidia/llama-nemotron-embed-1b-v2",
architecture="LlamaBidirectionalModel",
mteb_score=0.689164662128673,
)
]
RERANK_MODELS = [
LASTPoolingRerankModelInfo(
"nvidia/llama-nemotron-rerank-1b-v2",
architecture="LlamaBidirectionalForSequenceClassification",
chat_template_name="nemotron-rerank.jinja",
mteb_score=0.33994,
),
]
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)

View File

@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
return
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
# accuracy issues
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

View File

@ -8,7 +8,7 @@ from PIL.Image import Image
from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.utils import encode_image_url
MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
@ -31,10 +31,7 @@ def test_keye_vl(
question: str,
):
images = [asset.pil_image for asset in image_assets]
image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
]
image_urls = [encode_image_url(image) for image in images]
engine_args = EngineArgs(
model=MODEL_NAME,

View File

@ -15,7 +15,7 @@ from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.utils import encode_image_url
from vllm.multimodal.video import sample_frames_from_video
from vllm.platforms import current_platform
@ -178,8 +178,7 @@ def build_dots_ocr_prompt(images, config):
"""Build Dots.OCR specific prompt with OCR instructions."""
# Use only stop_sign image for Dots.OCR
image = images[0] # Already filtered to stop_sign
image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}"
image_url = encode_image_url(image)
placeholders = [{"type": "image_url", "image_url": {"url": image_url}}]
messages = [
@ -204,9 +203,7 @@ def build_processor_prompt(images, config):
config["model_name"], trust_remote_code=True
)
image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
]
image_urls = [encode_image_url(img) for img in images]
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
@ -225,9 +222,7 @@ def build_processor_prompt(images, config):
def build_ovis_prompt(images, config):
"""Build Ovis2.5 specific prompt with custom format."""
image_urls = [
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
]
image_urls = [encode_image_url(img) for img in images]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)

View File

@ -111,4 +111,5 @@ async def test_online_serving(client, audio_assets: AudioTestAssets):
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.message.content == "In the first audio clip, you hear a brief"
assert choice.finish_reason == "length"

View File

@ -488,6 +488,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"LlamaBidirectionalModel": _HfExamplesInfo(
"nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True
),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo(
"Alibaba-NLP/gte-modernbert-base", trust_remote_code=True
@ -554,6 +557,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
trust_remote_code=True,
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
),
"LlamaBidirectionalForSequenceClassification": _HfExamplesInfo(
"nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True
),
"ModernBertForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-reranker-modernbert-base"
),
@ -854,6 +860,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# disable this temporarily until we support HF format
is_available_online=False,
),
"VoxtralStreamingGeneration": _HfExamplesInfo(
"<place-holder>",
# disable this temporarily until we support HF format
is_available_online=False,
),
# [Encoder-decoder]
"WhisperForConditionalGeneration": _HfExamplesInfo(
"openai/whisper-large-v3-turbo",

View File

@ -399,6 +399,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo):
@dataclass
class RerankModelInfo(ModelInfo):
mteb_score: float | None = None
chat_template_name: str | None = None
@dataclass

View File

@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
"single_tool_add",
"single_tool_weather",
"multiple_tool_calls",
"complex",
"wrong_json",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
],
None,
),
(
# Complex
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
)[:-2],
)
)
],
"hi{hi",
),
(
# Wrong json
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"hi{hi",
),
],
)
def test_extract_tool_calls(
@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
),
(
# Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
)
)
],
"",
"hi{hi",
),
],
)

View File

@ -27,7 +27,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA,
@ -289,7 +289,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend(
backend: AttentionBackendEnum,
kv_cache_spec: MLAAttentionSpec,
kv_cache_spec: FullAttentionSpec,
layer_names: list[str],
vllm_config,
device: torch.device,
@ -740,7 +740,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = MLAAttentionSpec(
backend_kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
@ -748,7 +748,6 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
backend_output = run_attention_backend(

View File

@ -31,7 +31,7 @@ import openai
import requests
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.utils import encode_image_url
MAX_OUTPUT_LEN = 256
@ -49,9 +49,7 @@ SAMPLE_PROMPTS_MM: list[dict] = [
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image;base64,{encode_image_base64(image_1)}"
},
"image_url": {"url": encode_image_url(image_1)},
},
{"type": "text", "text": "What's in this image?"},
],
@ -66,9 +64,7 @@ SAMPLE_PROMPTS_MM: list[dict] = [
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image;base64,{encode_image_base64(image_2)}"
},
"image_url": {"url": encode_image_url(image_2)},
},
{
"type": "image_url",

View File

@ -260,7 +260,7 @@ async def test_multi_abort(output_kind: RequestOutputKind):
# Use multi-abort to abort multiple requests at once
abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
await engine.abort(abort_request_ids)
await engine.abort(abort_request_ids, internal=False)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True)
@ -609,7 +609,7 @@ async def test_abort_final_output(output_kind: RequestOutputKind):
await asyncio.sleep(0.5)
# Abort the request
await engine.abort(request_id)
await engine.abort(request_id, internal=False)
# Wait for generation to complete and return final output
final_output = await generated

View File

@ -40,10 +40,16 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "I am Gyoubu Masataka Oniwa"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
_REQUEST_COUNTER = 0
def make_request() -> EngineCoreRequest:
global _REQUEST_COUNTER
_REQUEST_COUNTER += 1
request_id = f"request-{_REQUEST_COUNTER}"
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
request_id=request_id,
external_req_id=f"{request_id}-{uuid.uuid4()}",
prompt_token_ids=PROMPT_TOKENS,
mm_features=None,
sampling_params=SamplingParams(),

View File

@ -45,6 +45,8 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
_REQUEST_COUNTER = 0
def make_request(
params: SamplingParams, prompt_tokens_ids: list[int] | None = None
@ -52,8 +54,12 @@ def make_request(
if not prompt_tokens_ids:
prompt_tokens_ids = PROMPT_TOKENS
global _REQUEST_COUNTER
_REQUEST_COUNTER += 1
request_id = f"request-{_REQUEST_COUNTER}"
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
request_id=request_id,
external_req_id=f"{request_id}-{uuid.uuid4()}",
prompt_token_ids=prompt_tokens_ids,
mm_features=None,
sampling_params=params,

View File

@ -27,6 +27,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest(
request_id="test",
external_req_id="test-ext",
prompt_token_ids=prompt_token_ids,
mm_features=None,
sampling_params=params,

View File

@ -58,12 +58,12 @@ def test_incremental_detokenization(
output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
)
engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
# Make N requests.
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
request_id=f"request-{idx}-int",
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -83,6 +83,11 @@ def test_incremental_detokenization(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt)
@ -438,15 +443,6 @@ def test_logprobs_processor(
dummy_test_vectors,
):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
prompt_logprobs_raw=None
if num_prompt_logprobs is None
else dummy_test_vectors.prompt_logprobs,
)
# Make N requests.
request_id_list = [
@ -454,7 +450,8 @@ def test_logprobs_processor(
]
requests = [
EngineCoreRequest(
request_id=request_id_list[idx],
request_id=request_id_list[idx] + "-int",
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -476,6 +473,17 @@ def test_logprobs_processor(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
prompt_logprobs_raw=None
if num_prompt_logprobs is None
else dummy_test_vectors.prompt_logprobs,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt)
@ -621,19 +629,12 @@ def test_stop_token(
]
prompt_string = dummy_test_vectors.prompt_strings[0]
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
engine_core = MockEngineCore(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos,
)
# Make request.
request_id = "request-0"
request = EngineCoreRequest(
request_id=request_id,
external_req_id=request_id + "-ext",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=eos_token_id,
@ -655,6 +656,16 @@ def test_stop_token(
pooling_params=None,
)
engine_core = MockEngineCore(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos,
request_ids=[request.request_id],
)
# Add request to the detokenizer.
output_processor.add_request(request, prompt_string)
@ -720,13 +731,6 @@ def test_stop_string(
dummy_test_vectors,
):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs
else None,
prompt_logprobs_raw=None,
)
# Make N requests.
request_id_list = [
@ -734,7 +738,8 @@ def test_stop_string(
]
requests = [
EngineCoreRequest(
request_id=request_id_list[idx],
request_id=request_id_list[idx] + "-int",
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -756,6 +761,15 @@ def test_stop_string(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs
else None,
prompt_logprobs_raw=None,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt)
@ -813,9 +827,12 @@ def test_stop_string(
for idx, (ref_gen_str, stop_str) in enumerate(
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
):
# Request should be aborted.
# Request should be aborted (check internal ID in abort list).
internal_request_id = f"request-{idx}-int"
assert internal_request_id in aborted
# Use external ID for collecting outputs
request_id = f"request-{idx}"
assert request_id in aborted
# Collected values that were generated.
gen_str = gen_strings[request_id]
@ -848,13 +865,13 @@ def test_stop_string(
def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
# Make N requests.
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
external_req_id=f"request-{idx}-ext",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -868,6 +885,11 @@ def test_iteration_stats(dummy_test_vectors):
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add all requests except one to the OutputProcessor.
num_active = len(dummy_test_vectors.generation_tokens) - 1
for request in requests[:num_active]:
@ -922,7 +944,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=log_stats
)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
# Create LoRA requests
@ -936,7 +957,8 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
lora_assignments = [lora1, lora2, None]
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
request_id=f"request-{idx}-int",
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -950,6 +972,11 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add all requests to the OutputProcessor
for request in requests:
output_processor.add_request(request, None)
@ -1015,9 +1042,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-0 as finished (it uses lora-1)
# Find and mark request-0-int as finished (it uses lora-1)
for output in outputs.outputs:
if output.request_id == "request-0":
if output.request_id == "request-0-int":
output.finish_reason = FinishReason.LENGTH
break
@ -1040,9 +1067,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-1 as finished (it uses lora-2)
# Find and mark request-1-int as finished (it uses lora-2)
for output in outputs.outputs:
if output.request_id == "request-1":
if output.request_id == "request-1-int":
output.finish_reason = FinishReason.LENGTH
break
@ -1064,9 +1091,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-2 as finished (it has no LoRA)
# Find and mark request-2-int as finished (it has no LoRA)
for output in outputs.outputs:
if output.request_id == "request-2":
if output.request_id == "request-2-int":
output.finish_reason = FinishReason.LENGTH
break
@ -1107,7 +1134,9 @@ async def test_request_output_collector():
for idx in range(NUM_REQS)
]
collector = RequestOutputCollector(RequestOutputKind.DELTA)
collector = RequestOutputCollector(
RequestOutputKind.DELTA, request_id="my-request-id-int"
)
# CASE 1: Put then get.
outputs = make_outputs()
@ -1163,7 +1192,9 @@ async def test_request_output_collector():
@pytest.mark.asyncio
async def test_cumulative_output_collector_n():
"""Test collector correctly handles multiple outputs by index."""
collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
collector = RequestOutputCollector(
RequestOutputKind.CUMULATIVE, request_id="my-request-id-int"
)
outputs = [
RequestOutput(
request_id="my-request-id",
@ -1242,11 +1273,13 @@ async def test_cumulative_output_collector_n():
@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
@pytest.mark.parametrize("abort_by", ["internal", "external"])
def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
external_req_id=f"external-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -1265,8 +1298,13 @@ def test_abort_requests(runner: str, dummy_test_vectors):
output_kind = request.sampling_params.output_kind
else:
output_kind = request.pooling_params.output_kind
queue = RequestOutputCollector(output_kind=output_kind)
queue = RequestOutputCollector(
output_kind=output_kind, request_id=request.request_id
)
output_processor.add_request(request, None, queue=queue)
for request in requests:
output_processor.abort_requests([request.request_id])
if abort_by == "internal":
output_processor.abort_requests([request.request_id], internal=True)
else:
output_processor.abort_requests([request.external_req_id], internal=False)

View File

@ -4,11 +4,12 @@
from vllm import SamplingParams
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.parallel_sampling import ParentRequest
def test_parent_request_to_output_stream() -> None:
parent_request = ParentRequest("parent_id", SamplingParams(n=2))
parent_request = ParentRequest(make_request(SamplingParams(n=2)))
parent_request.child_requests = {"child_id_0", "child_id_1"}
output_0 = CompletionOutput(
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
@ -17,51 +18,31 @@ def test_parent_request_to_output_stream() -> None:
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
)
# Request not finished
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
# output_1 finished
output_1.finish_reason = "ended"
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
# Finished output_1 had already returned, DO NOT returned again
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
# output_0 finished
output_0.finish_reason = "ended"
assert ("parent_id", [output_0], True) == parent_request.get_outputs(
"child_id_0", output_0
)
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
assert ([output_0], True) == parent_request.get_outputs("child_id_0", output_0)
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
# Finished output_0 had already returned, DO NOT returned again
assert parent_request.get_outputs("child_id_0", output_0) == ("parent_id", [], True)
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
assert parent_request.get_outputs("child_id_0", output_0) == ([], True)
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
def test_parent_request_to_output_final_only() -> None:
parent_request = ParentRequest(
"parent_id", SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY)
make_request(SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY))
)
parent_request.child_requests = {"child_id_0", "child_id_1"}
output_0 = CompletionOutput(
@ -71,33 +52,33 @@ def test_parent_request_to_output_final_only() -> None:
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
)
# Request not finished, return nothing
assert parent_request.get_outputs("child_id_0", output_0) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
# output_1 finished, but outputs won't be returned until all child requests finished
output_1.finish_reason = "ended"
assert parent_request.get_outputs("child_id_0", output_0) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
# output_0 finished, as all child requests finished, the output would be returned
output_0.finish_reason = "ended"
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
assert ([output_0, output_1], True) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
assert ([output_0, output_1], True) == parent_request.get_outputs(
"child_id_1", output_1
)
def make_request(sampling_params: SamplingParams) -> EngineCoreRequest:
return EngineCoreRequest(
request_id="parent_id",
external_req_id="ext_parent_id",
prompt_token_ids=None,
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)

View File

@ -5,6 +5,7 @@ import pytest
import torch.cuda
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
@ -14,6 +15,11 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
"""Test that preprocessing errors are handled gracefully."""
if current_platform.is_rocm():
pytest.skip(
"Skipped on ROCm: this test only works with 'fork', but ROCm uses 'spawn'."
)
assert not torch.cuda.is_initialized(), (
"fork needs to be used for the engine "
"core process and this isn't possible if cuda is already initialized"

View File

@ -6,6 +6,7 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import input_processor as input_processor_mod
from vllm.v1.engine.input_processor import InputProcessor
@ -166,7 +167,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
)
captured: dict[str, object] = {}
captured: dict[str, MultiModalUUIDDict] = {}
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
@ -196,7 +197,16 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
)
# Expect request-id-based overrides are passed through
assert captured["mm_uuids"] == {
"image": [f"{request_id}-image-0", f"{request_id}-image-1"],
"video": [f"{request_id}-video-0"],
}
mm_uuids = captured["mm_uuids"]
assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1
assert mm_uuids["image"][0].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][0].endswith("-0")
assert mm_uuids["image"][1].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][1].endswith("-1")
assert mm_uuids["video"][0].startswith(f"{request_id}-video-") and mm_uuids[
"video"
][0].endswith("-0")

View File

@ -343,6 +343,7 @@ class MockEngineCore:
eos_token_id: int | None = None,
stop_token_ids: list[int] | None = None,
ignore_eos: bool = False,
request_ids: list[str] | None = None,
) -> None:
self.num_requests = len(tokens_list)
self.tokens_list = tokens_list
@ -355,6 +356,11 @@ class MockEngineCore:
self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids
self.ignore_eos = ignore_eos
self.request_ids = (
request_ids
if request_ids is not None
else [f"request-{i}" for i in range(self.num_requests)]
)
def get_outputs(self) -> list[EngineCoreOutput]:
do_logprobs = self.do_logprobs
@ -386,7 +392,7 @@ class MockEngineCore:
prompt_logprobs = None
new_token_id = token_ids[token_idx]
output = EngineCoreOutput(
request_id=f"request-{req_idx}",
request_id=self.request_ids[req_idx],
new_token_ids=[new_token_id],
new_logprobs=logprobs,
new_prompt_logprobs_tensors=prompt_logprobs,

View File

@ -8,7 +8,7 @@ import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.utils import encode_image_url
# Use a small vision model for testing
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
@ -52,9 +52,9 @@ async def client(image_server):
@pytest.fixture(scope="session")
def base64_encoded_image(local_asset_server) -> dict[str, str]:
def url_encoded_image(local_asset_server) -> dict[str, str]:
return {
image_url: encode_image_base64(local_asset_server.get_image_asset(image_url))
image_url: encode_image_url(local_asset_server.get_image_asset(image_url))
for image_url in TEST_IMAGE_ASSETS
}
@ -95,7 +95,7 @@ async def test_single_chat_session_image_base64encoded(
client: openai.AsyncOpenAI,
model_name: str,
raw_image_url: str,
base64_encoded_image: dict[str, str],
url_encoded_image: dict[str, str],
):
content_text = "What's in this image?"
messages = [
@ -104,7 +104,7 @@ async def test_single_chat_session_image_base64encoded(
"content": [
{
"type": "input_image",
"image_url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", # noqa: E501
"image_url": url_encoded_image[raw_image_url],
"detail": "auto",
},
{"type": "input_text", "text": content_text},

View File

@ -9,7 +9,7 @@ from PIL import Image
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.utils import encode_image_url
from vllm.platforms import current_platform
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
@ -74,7 +74,7 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
placeholders = [
{
"type": "image_url",
"image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
"image_url": {"url": encode_image_url(image_pil)},
}
for image_pil in image_urls
]

View File

@ -41,10 +41,13 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group,
)
from vllm.forward_context import ForwardContext
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
@ -1265,6 +1268,22 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
run_test_and_cleanup()
class RequestIdMapper:
"""Helper class to map external request IDs to internal request IDs."""
def __init__(self, output_processor: OutputProcessor):
self.req_id_mapping: dict[str, str] = {}
self.original_add_request = output_processor.add_request
output_processor.add_request = self._add_request
def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
self.req_id_mapping[request.external_req_id] = request.request_id
return self.original_add_request(request, *args, **kwargs)
def __call__(self, external_req_id: str) -> str:
return self.req_id_mapping[external_req_id]
def _run_abort_timeout_test(llm: LLM, timeout: int):
"""Helper function to run the abort timeout test logic."""
remote_prefill_opts = {
@ -1286,24 +1305,34 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
0
].req_to_blocks
id_mapper = RequestIdMapper(llm.llm_engine.output_processor)
def req_id(outputs: list[RequestOutput]) -> str:
assert len(outputs) == 1
return id_mapper(outputs[0].request_id)
padding = "Just making this request a little longer so that we're sure "
"we're not hitting the small-request lower bound beneath which we don't "
"actually trigger the whole kv transfer, but rather just recompute the "
"blocks on D."
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
req0_id = req_id(
llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
)
# Request finished but not freed
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
# Some other request, 0 still not freed
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
assert "0" in req_to_blocks
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
req1_id = req_id(
llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
)
assert req0_id in req_to_blocks
assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks
# Wait for timeout and trigger another scheduler loop
time.sleep(timeout)
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
assert req0_id not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()

View File

@ -4,7 +4,7 @@
import openai
import pytest
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.utils import encode_image_url
from vllm.platforms import current_platform
from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS
@ -12,11 +12,9 @@ from ...utils import RemoteOpenAIServer
@pytest.fixture(scope="session")
def base64_encoded_image(local_asset_server) -> dict[str, str]:
def url_encoded_image(local_asset_server) -> dict[str, str]:
return {
image_asset: encode_image_base64(
local_asset_server.get_image_asset(image_asset)
)
image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
for image_asset in TEST_IMAGE_ASSETS
}
@ -24,19 +22,16 @@ def base64_encoded_image(local_asset_server) -> dict[str, str]:
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, str]):
async def test_basic_vision(model_name: str, url_encoded_image: dict[str, str]):
pytest.skip("Skip this test until it's fixed.")
def whats_in_this_image_msg(b64):
def whats_in_this_image_msg(url):
return [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{b64}"},
},
{"type": "image_url", "image_url": {"url": url}},
],
}
]
@ -63,14 +58,14 @@ async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, str
# Other requests now should be much faster
for image_url in TEST_IMAGE_ASSETS:
image_base64 = base64_encoded_image[image_url]
chat_completion_from_base64 = await client.chat.completions.create(
image_url = url_encoded_image[image_url]
chat_completion_from_url = await client.chat.completions.create(
model=model_name,
messages=whats_in_this_image_msg(image_base64),
messages=whats_in_this_image_msg(image_url),
max_completion_tokens=24,
temperature=0.0,
)
result = chat_completion_from_base64
result = chat_completion_from_url
assert result
choice = result.choices[0]
assert choice.finish_reason == "length"

View File

@ -4,6 +4,7 @@ import functools
from collections.abc import Callable
import torch
from torch._ops import OpOverload
import vllm.envs as envs
from vllm.platforms import current_platform
@ -433,16 +434,16 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
from aiter import rmsnorm2d_fwd_with_add
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
out = torch.empty_like(x)
rmsnorm2d_fwd_with_add(
output, # output
out, # output
x, # input
residual, # residual input
residual_out, # residual output
weight,
variance_epsilon,
)
return output, residual_out
return out, residual_out
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
@ -451,7 +452,84 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
residual_out = torch.empty_like(residual)
out = torch.empty_like(x)
return out, residual_out
def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
assert quant_dtype in [torch.int8, _FP8_DTYPE]
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
residual_out = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant(
out,
x,
residual,
residual_out,
y_scale,
weight,
epsilon,
use_model_sensitive_rmsnorm=0,
)
return out, residual_out, y_scale
def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
residual_out = torch.empty_like(x)
return out, residual_out, y_scale
def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
assert quant_dtype in [torch.int8, _FP8_DTYPE]
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
rocm_aiter.rmsnorm2d_fwd_with_dynamicquant(
out, x, y_scale, weight, epsilon, use_model_sensitive_rmsnorm=0
)
return out, y_scale
def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
return out, y_scale
def _rocm_aiter_per_tensor_quant_impl(
@ -527,7 +605,11 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
dtype_quant=AITER_FP8_DTYPE,
res1=residual,
)
return (x_quant, x_quant_scales, res)
return (
x_quant,
res,
x_quant_scales,
)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
@ -541,8 +623,8 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
torch.empty_like(residual, device=residual.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
@ -901,6 +983,20 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fused_dynamic_quant",
op_func=_rocm_aiter_rmsnorm_fused_dynamic_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fused_dynamic_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fused_add_dynamic_quant",
op_func=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fused_add_dynamic_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
@ -936,13 +1032,54 @@ class rocm_aiter_ops:
direct_register_custom_op(
op_name="rocm_aiter_per_token_quant",
op_func=_rocm_aiter_per_token_quant_impl,
mutates_args=["scale"],
fake_impl=_rocm_aiter_per_token_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True
@staticmethod
def get_rmsnorm_fused_add_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
@staticmethod
def get_rmsnorm_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rms_norm.default
@staticmethod
def get_rmsnorm_fused_add_dynamic_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default
@staticmethod
def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rmsnorm_fused_dynamic_quant.default
@staticmethod
def get_rmsnorm_group_fused_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
@staticmethod
def get_rmsnorm_group_add_fused_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
@staticmethod
def get_per_token_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_per_token_quant.default
@staticmethod
def get_group_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_group_fp8_quant.default
@staticmethod
def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
@staticmethod
def rms_norm2d_with_add(
x: torch.Tensor,
@ -954,12 +1091,6 @@ class rocm_aiter_ops:
x, residual, weight, variance_epsilon
)
@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
@staticmethod
def gemm_a8w8(
A: torch.Tensor,

View File

@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens=cu_seqlens,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
output = output.reshape(bsz, q_len, -1)
return output
def _forward_fa(
@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp):
fa_version=self._fa_version,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
output = output.reshape(bsz, q_len, -1)
return output
def forward_native(

View File

@ -6,11 +6,13 @@ import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
_normalize_quant_group_shape,
kFp8Dynamic64Sym,
@ -150,26 +152,50 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
class MatcherRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self._rmsnorm_op = RMS_OP
self.match_rocm_aiter = match_rocm_aiter
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return self._rmsnorm_op(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight)
result = torch.empty_like(input)
_, result = auto_functionalized(
RMS_OP,
self._rmsnorm_op,
result=result,
input=input,
weight=weight,
@ -189,12 +215,23 @@ class MatcherRMSNorm(MatcherCustomOp):
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self.match_rocm_aiter = match_rocm_aiter
self._rmsnorm_op = RMS_ADD_OP
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
@ -202,14 +239,27 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
residual = self.empty(5, 16)
return [input, weight, residual]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self._rmsnorm_op(
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight, residual)
_, result, residual = auto_functionalized(
RMS_ADD_OP,
self._rmsnorm_op,
input=input,
residual=residual,
weight=weight,
@ -236,22 +286,46 @@ class MatcherQuantFP8(MatcherCustomOp):
enabled: bool | None = None,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
match_rocm_aiter: bool = False,
):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
self.match_rocm_aiter = match_rocm_aiter
if match_rocm_aiter:
assert not quant_key.scale.group_shape.is_per_tensor(), (
"ROCm aiter fusion pass does not support per tensor quantization"
)
if quant_key.scale.group_shape.is_per_token():
self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
else:
assert quant_key.scale.group_shape.col == 128, (
"ROCm aiter fusion pass currently supports "
"quantization operation with group_size 128"
)
if current_platform.is_fp8_fnuz():
self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
else:
self.QUANT_OP = (
torch.ops.vllm.triton_per_token_group_quant_fp8.default
)
else:
assert quant_key in QUANT_OPS, (
f"unsupported quantization scheme {quant_key}"
)
self.QUANT_OP = QUANT_OPS[quant_key]
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(
quant_key.scale.static,
quant_key.scale.group_shape,
@ -259,11 +333,29 @@ class MatcherQuantFP8(MatcherCustomOp):
use_ue8m0=is_e8m0,
)
def forward_rocm_aiter(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
quant_key_group_shape = self.quant_key.scale.group_shape
if quant_key_group_shape == GroupShape.PER_TOKEN:
return self.QUANT_OP(
x=input,
quant_dtype=self.quant_key.dtype,
scale=scale,
)
else:
return self.QUANT_OP(input, quant_key_group_shape.col)
def forward_custom(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, scale)
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_key.dtype
)

View File

@ -16,7 +16,7 @@ from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterRMSNormFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
@ -117,7 +117,9 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
self.passes += [
RocmAiterRMSNormFusionPass(config),
]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():

View File

@ -9,60 +9,195 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.platforms import current_platform
from .fusion import empty_bf16
from .fusion import (
FusedRMSQuantKey,
)
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherSiluAndMul
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
MatcherSiluAndMul,
)
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
class AiterRMSNormQuantPattern:
def __init__(
self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
)
self.quant_matcher = MatcherQuantFP8(
key.quant,
match_rocm_aiter=match_aiter_quant,
)
class AiterRMSFp8GroupQuantPattern:
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm + Dynamic Quantization pattern."""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
result = self.FUSED_OP(
x=input,
weight=weight,
epsilon=self.epsilon,
quant_dtype=self.quant_dtype,
)
return result[0], result[1]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm Fused Add + Dynamic Quantization pattern."""
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual_out, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
result = self.FUSED_OP(
x=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
quant_dtype=self.quant_dtype,
)
return result[0], result[1], result[2]
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_GROUP_QUANT_OP(
at = self.FUSED_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
@ -71,49 +206,52 @@ class AiterRMSFp8GroupQuantPattern:
return at[0], at[1]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
)
class AiterFusedAddRMSFp8GroupQuantPattern:
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into a aiter rms_norm_with_add_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()
def __init__(
self,
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric=True,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
at1 = AITER_RMS_ADD_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
at2 = self.quant_op(at1[0], 128)
# result, scale, residual
return at2[0], at2[1], at1[1]
return result, residual_out, scale
def replacement(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
at = AITER_RMS_ADD_GROUP_QUANT_OP(
at = self.FUSED_OP(
x=input,
residual=residual,
weight=weight,
@ -124,18 +262,15 @@ class AiterFusedAddRMSFp8GroupQuantPattern:
# result, scale, residual
return at[0], at[1], at[2]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
)
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@ -144,20 +279,33 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + dynamic group fp8 quant
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
self.patterns
)
# Fuse aiter rms_norm + aiter dynamic group fp8 quant
AiterRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, GroupShape(1, 128)
).register(self.patterns)
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, quant_op
# Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, GroupShape(1, 128)
).register(self.patterns)
for match_aiter_quant in [True, False]:
# Fuse aiter rms_norm + (aiter / vllm built-in)
# dynamic per-token fp8 quant
AiterRMSNormDynamicQuantPattern(
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
).register(self.patterns)
# Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
# dynamic per-token fp8 quant
AiterFusedAddRMSNormDynamicQuantPattern(
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@ -169,6 +317,8 @@ class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
def uuid(self) -> Any:
fusion_patterns = [
AiterRMSNormDynamicQuantPattern,
AiterFusedAddRMSNormDynamicQuantPattern,
AiterRMSFp8GroupQuantPattern,
AiterFusedAddRMSFp8GroupQuantPattern,
]
@ -181,6 +331,8 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
ops into an aiter silu_and_mul_group_fp8_quant op.
"""
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload):
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
@ -196,7 +348,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
def replacement(
input: torch.Tensor,
):
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
inputs = [
@ -216,6 +368,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
@ -224,7 +381,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
for quant_op in self.QUANT_OPS:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
self.dump_patterns(config, self.patterns)

View File

@ -11,7 +11,6 @@ import torch
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
@ -29,6 +28,7 @@ from vllm.transformers_utils.config import (
get_pooling_config,
get_sentence_transformer_tokenizer_config,
is_encoder_decoder,
is_rope_parameters_nested,
try_get_dense_modules,
try_get_generation_config,
try_get_safetensors_metadata,
@ -1094,11 +1094,10 @@ class ModelConfig:
# The size of inputs_embeds is usually identical to the size
# of the hidden states, however there are exceptions, such as
# embedding models like CLIP and SigLIP
for target_attr in ("projection_dim", "projection_size"):
if hasattr(self.hf_text_config, target_attr):
return getattr(self.hf_text_config, target_attr)
return self.get_hidden_size()
names = ("projection_dim", "projection_size")
return getattr_iter(
self.hf_text_config, names, default_factory=self.get_hidden_size
)
@property
def is_deepseek_mla(self) -> bool:
@ -1231,14 +1230,12 @@ class ModelConfig:
# For ChatGLM:
"multi_query_group_num",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads
default_factory = lambda: self.hf_text_config.num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
"""Returns the number of KV heads per GPU."""
@ -1542,6 +1539,10 @@ class ModelConfig:
def is_multimodal_raw_input_only_model(self) -> bool:
return self._model_info.supports_multimodal_raw_input_only
@property
def requires_raw_input_tokens(self) -> bool:
return self._model_info.requires_raw_input_tokens
@property
def is_cross_encoder(self) -> bool:
return (
@ -2125,9 +2126,7 @@ def _get_and_verify_max_len(
# In Transformers v5 rope_parameters could be TypedDict or dict[str, TypedDict].
# To simplify the verification, we convert it to dict[str, TypedDict].
rope_parameters = getattr(hf_config, "rope_parameters", None)
if rope_parameters and not set(rope_parameters.keys()).issubset(
ALLOWED_LAYER_TYPES
):
if rope_parameters and not is_rope_parameters_nested(rope_parameters):
rope_parameters = {"": rope_parameters}
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE

View File

@ -9,7 +9,7 @@ import inspect
import json
import pathlib
import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set
from collections.abc import Callable, Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False
object: object,
names: Iterable[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,
) -> Any:
"""
A helper function that retrieves an attribute from an object which may
@ -96,7 +100,7 @@ def getattr_iter(
names[0],
)
return getattr(object, name)
return default
return default_factory() if default_factory is not None else default
def contains_object_print(text: str) -> bool:

View File

@ -67,6 +67,15 @@ else:
logger = init_logger(__name__)
class ChatTemplateResolutionError(ValueError):
"""Raised when chat template resolution fails.
This is a subclass of ValueError for backward compatibility with
existing exception handlers.
"""
MODALITY_PLACEHOLDERS_MAP = {
"image": "<##IMAGE##>",
"audio": "<##AUDIO##>",
@ -1814,7 +1823,7 @@ def apply_hf_chat_template(
)
if hf_chat_template is None:
raise ValueError(
raise ChatTemplateResolutionError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."

View File

@ -1280,6 +1280,7 @@ class LLM:
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
score_template: str | None = None,
) -> list[ScoringRequestOutput]:
model_config = self.model_config
@ -1313,6 +1314,7 @@ class LLM:
data_2=d,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
score_template=score_template,
)
if token_type_ids := engine_prompt.pop("token_type_ids", None):
@ -1347,6 +1349,7 @@ class LLM:
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
chat_template: str | None = None,
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or
`<multi-modal data, multi-modal data pair>`.
@ -1379,6 +1382,8 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
chat_template: The chat template to use for the scoring. If None, we
use the model's default chat template.
Returns:
A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts.
@ -1406,6 +1411,11 @@ class LLM:
):
raise ValueError("Score API is only enabled for num_labels == 1.")
if not model_config.is_cross_encoder and chat_template is not None:
raise ValueError(
"chat_template is only supported for cross-encoder models."
)
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
@ -1475,6 +1485,7 @@ class LLM:
use_tqdm,
pooling_params,
lora_request,
score_template=chat_template,
)
else:
return self._embedding_score(
@ -1610,7 +1621,7 @@ class LLM:
added_request_ids.append(request_id)
except Exception as e:
if added_request_ids:
self.llm_engine.abort_request(added_request_ids)
self.llm_engine.abort_request(added_request_ids, internal=True)
raise e
def _validate_mm_data_and_uuids(
@ -1720,7 +1731,7 @@ class LLM:
priority=priority,
prompt_text=prompt_text,
)
return request_id
return engine_request.request_id
def _run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True

View File

@ -909,6 +909,16 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: Request, exc: RequestValidationError):
from vllm.entrypoints.openai.protocol import VLLMValidationError
param = None
for error in exc.errors():
if "ctx" in error and "error" in error["ctx"]:
ctx_error = error["ctx"]["error"]
if isinstance(ctx_error, VLLMValidationError):
param = ctx_error.parameter
break
exc_str = str(exc)
errors_str = str(exc.errors())
@ -922,6 +932,7 @@ def build_app(args: Namespace) -> FastAPI:
message=message,
type=HTTPStatus.BAD_REQUEST.phrase,
code=HTTPStatus.BAD_REQUEST,
param=param,
)
)
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
@ -1145,6 +1156,7 @@ async def init_app_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
)
if ("embed" in supported_tasks or "score" in supported_tasks)

View File

@ -131,6 +131,36 @@ class ErrorResponse(OpenAIBaseModel):
error: ErrorInfo
class VLLMValidationError(ValueError):
"""vLLM-specific validation error for request validation failures.
Args:
message: The error message describing the validation failure.
parameter: Optional parameter name that failed validation.
value: Optional value that was rejected during validation.
"""
def __init__(
self,
message: str,
*,
parameter: str | None = None,
value: Any = None,
) -> None:
super().__init__(message)
self.parameter = parameter
self.value = value
def __str__(self):
base = super().__str__()
extras = []
if self.parameter is not None:
extras.append(f"parameter={self.parameter}")
if self.value is not None:
extras.append(f"value={self.value}")
return f"{base} ({', '.join(extras)})" if extras else base
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
@ -466,7 +496,9 @@ class ResponsesRequest(OpenAIBaseModel):
@model_validator(mode="before")
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise ValueError("prompt template is not supported")
raise VLLMValidationError(
"prompt template is not supported", parameter="prompt"
)
return data
@model_validator(mode="before")
@ -850,7 +882,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError("Stream options can only be defined when `stream=True`.")
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data
@ -859,19 +894,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`."
raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
)
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise ValueError("`prompt_logprobs` must be a positive value or -1.")
raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0 and top_logprobs != -1:
raise ValueError("`top_logprobs` must be a positive value or -1.")
raise VLLMValidationError(
"`top_logprobs` must be a positive value or -1.",
parameter="top_logprobs",
value=top_logprobs,
)
if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
raise VLLMValidationError(
"when using `top_logprobs`, `logprobs` must be set to true.",
parameter="top_logprobs",
)
return data
@ -1285,9 +1330,10 @@ class CompletionRequest(OpenAIBaseModel):
for k in ("json", "regex", "choice")
)
if count > 1:
raise ValueError(
raise VLLMValidationError(
"You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice')."
"outputs ('json', 'regex' or 'choice').",
parameter="structured_outputs",
)
return data
@ -1296,14 +1342,23 @@ class CompletionRequest(OpenAIBaseModel):
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`."
raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
)
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise ValueError("`prompt_logprobs` must be a positive value or -1.")
raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.")
raise VLLMValidationError(
"`logprobs` must be a positive value.",
parameter="logprobs",
value=logprobs,
)
return data
@ -1311,7 +1366,10 @@ class CompletionRequest(OpenAIBaseModel):
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError("Stream options can only be defined when `stream=True`.")
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data
@ -2138,7 +2196,15 @@ class TranscriptionRequest(OpenAIBaseModel):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
raise ValueError("Stream options can only be defined when `stream=True`.")
# Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data
@ -2351,7 +2417,15 @@ class TranslationRequest(OpenAIBaseModel):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
raise ValueError("Stream options can only be defined when `stream=True`.")
# Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data

View File

@ -495,6 +495,7 @@ async def run_batch(
engine_client,
openai_serving_models,
request_logger=request_logger,
score_template=None,
)
if ("embed" in supported_tasks or enable_serving_reranking)
else None

View File

@ -417,8 +417,7 @@ class OpenAIServingChat(OpenAIServing):
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
assert len(generators) == 1
(result_generator,) = generators
@ -448,8 +447,7 @@ class OpenAIServingChat(OpenAIServing):
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@ -682,7 +680,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parsers = [None] * num_choices
except Exception as e:
logger.exception("Error in tool parser creation.")
data = self.create_streaming_error_response(str(e))
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return
@ -1328,9 +1326,8 @@ class OpenAIServingChat(OpenAIServing):
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(str(e))
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
@ -1354,8 +1351,7 @@ class OpenAIServingChat(OpenAIServing):
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
assert final_res is not None

View File

@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
VLLMValidationError,
)
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
@ -247,8 +248,7 @@ class OpenAIServingCompletion(OpenAIServing):
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
result_generator = merge_async_iterators(*generators)
@ -308,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing):
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
@ -510,9 +509,8 @@ class OpenAIServingCompletion(OpenAIServing):
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e))
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
@ -660,8 +658,11 @@ class OpenAIServingCompletion(OpenAIServing):
token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
raise VLLMValidationError(
"Unable to get tokenizer because "
"`skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
)
token = tokenizer.decode(token_id)
@ -720,6 +721,15 @@ class OpenAIServingCompletion(OpenAIServing):
request: CompletionRequest,
max_input_length: int | None = None,
) -> RenderConfig:
# Validate max_tokens before using it
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
raise VLLMValidationError(
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
f"the model's maximum context length ({self.max_model_len}).",
parameter="max_tokens",
value=request.max_tokens,
)
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,

View File

@ -57,6 +57,7 @@ from vllm.entrypoints.openai.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranslationRequest,
VLLMValidationError,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
@ -322,8 +323,10 @@ class OpenAIServing:
input_processor = self.input_processor
tokenizer = input_processor.tokenizer
if tokenizer is None:
raise ValueError(
"You cannot use beam search when `skip_tokenizer_init=True`"
raise VLLMValidationError(
"You cannot use beam search when `skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
)
eos_token_id: int = tokenizer.eos_token_id # type: ignore
@ -706,8 +709,7 @@ class OpenAIServing:
return None
except Exception as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
async def _collect_batch(
self,
@ -738,14 +740,43 @@ class OpenAIServing:
return None
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response(e)
def create_error_response(
self,
message: str,
message: str | Exception,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None,
) -> ErrorResponse:
exc: Exception | None = None
if isinstance(message, Exception):
exc = message
from vllm.entrypoints.openai.protocol import VLLMValidationError
if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
else:
err_type = "InternalServerError"
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
param = None
message = str(exc)
if self.log_error_stack:
exc_type, _, _ = sys.exc_info()
if exc_type is not None:
@ -753,18 +784,27 @@ class OpenAIServing:
else:
traceback.print_stack()
return ErrorResponse(
error=ErrorInfo(message=message, type=err_type, code=status_code.value)
error=ErrorInfo(
message=message,
type=err_type,
code=status_code.value,
param=param,
)
)
def create_streaming_error_response(
self,
message: str,
message: str | Exception,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None,
) -> str:
json_str = json.dumps(
self.create_error_response(
message=message, err_type=err_type, status_code=status_code
message=message,
err_type=err_type,
status_code=status_code,
param=param,
).model_dump()
)
return json_str
@ -825,6 +865,7 @@ class OpenAIServing:
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
@ -991,11 +1032,13 @@ class OpenAIServing:
ClassificationChatRequest: "classification",
}
operation = operations.get(type(request), "embedding generation")
raise ValueError(
raise VLLMValidationError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input."
f"Please reduce the length of the input.",
parameter="input_tokens",
value=token_num,
)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
@ -1017,20 +1060,24 @@ class OpenAIServing:
# Note: input length can be up to model context length - 1 for
# completion-like requests.
if token_num >= self.max_model_len:
raise ValueError(
raise VLLMValidationError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, your request has "
f"{token_num} input tokens. Please reduce the length of "
"the input messages."
"the input messages.",
parameter="input_tokens",
value=token_num,
)
if max_tokens is not None and token_num + max_tokens > self.max_model_len:
raise ValueError(
raise VLLMValidationError(
"'max_tokens' or 'max_completion_tokens' is too large: "
f"{max_tokens}. This model's maximum context length is "
f"{self.max_model_len} tokens and your request has "
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
f" - {token_num})."
f" - {token_num}).",
parameter="max_tokens",
value=max_tokens,
)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

View File

@ -94,6 +94,7 @@ from vllm.entrypoints.openai.protocol import (
ResponsesResponse,
ResponseUsage,
StreamingResponsesResponse,
VLLMValidationError,
)
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
@ -271,6 +272,7 @@ class OpenAIServingResponses(OpenAIServing):
err_type="invalid_request_error",
message=error_message,
status_code=HTTPStatus.BAD_REQUEST,
param="input",
)
return None
@ -282,6 +284,7 @@ class OpenAIServingResponses(OpenAIServing):
err_type="invalid_request_error",
message="logprobs are not supported with gpt-oss models",
status_code=HTTPStatus.BAD_REQUEST,
param="logprobs",
)
if request.store and not self.enable_store and request.background:
return self.create_error_response(
@ -294,6 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
"the vLLM server."
),
status_code=HTTPStatus.BAD_REQUEST,
param="background",
)
if request.previous_input_messages and request.previous_response_id:
return self.create_error_response(
@ -301,6 +305,7 @@ class OpenAIServingResponses(OpenAIServing):
message="Only one of `previous_input_messages` and "
"`previous_response_id` can be set.",
status_code=HTTPStatus.BAD_REQUEST,
param="previous_response_id",
)
return None
@ -457,8 +462,7 @@ class OpenAIServingResponses(OpenAIServing):
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
assert len(generators) == 1
(result_generator,) = generators
@ -546,7 +550,7 @@ class OpenAIServingResponses(OpenAIServing):
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response(e)
async def _make_request(
self,
@ -630,8 +634,7 @@ class OpenAIServingResponses(OpenAIServing):
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
# NOTE: Implementation of stauts is still WIP, but for now
# we guarantee that if the status is not "completed", it is accurate.
@ -1074,7 +1077,7 @@ class OpenAIServingResponses(OpenAIServing):
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
response = self.create_error_response(e)
finally:
new_event_signal.set()
@ -1099,7 +1102,7 @@ class OpenAIServingResponses(OpenAIServing):
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
response = self.create_error_response(e)
if isinstance(response, ErrorResponse):
# If the request has failed, update the status to "failed".
@ -1116,7 +1119,11 @@ class OpenAIServingResponses(OpenAIServing):
starting_after: int | None = None,
) -> AsyncGenerator[StreamingResponsesResponse, None]:
if response_id not in self.event_store:
raise ValueError(f"Unknown response_id: {response_id}")
raise VLLMValidationError(
f"Unknown response_id: {response_id}",
parameter="response_id",
value=response_id,
)
event_deque, new_event_signal = self.event_store[response_id]
start_index = 0 if starting_after is None else starting_after + 1
@ -1172,6 +1179,7 @@ class OpenAIServingResponses(OpenAIServing):
return self.create_error_response(
err_type="invalid_request_error",
message="Cannot cancel a synchronous response.",
param="response_id",
)
# Update the status to "cancelled".
@ -1191,6 +1199,7 @@ class OpenAIServingResponses(OpenAIServing):
err_type="invalid_request_error",
message=f"Response with id '{response_id}' not found.",
status_code=HTTPStatus.NOT_FOUND,
param="response_id",
)
def _make_store_not_supported_error(self) -> ErrorResponse:
@ -1203,6 +1212,7 @@ class OpenAIServingResponses(OpenAIServing):
"starting the vLLM server."
),
status_code=HTTPStatus.BAD_REQUEST,
param="store",
)
async def _process_simple_streaming_events(

View File

@ -30,6 +30,7 @@ from vllm.entrypoints.openai.protocol import (
TranslationSegment,
TranslationStreamResponse,
UsageInfo,
VLLMValidationError,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@ -259,7 +260,11 @@ class OpenAISpeechToText(OpenAIServing):
)
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
raise VLLMValidationError(
"Maximum file size exceeded",
parameter="audio_filesize_mb",
value=len(audio_data) / 1024**2,
)
with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a
@ -287,12 +292,18 @@ class OpenAISpeechToText(OpenAIServing):
)
if request.response_format == "verbose_json":
if not isinstance(prompt, dict):
raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}")
raise VLLMValidationError(
"Expected prompt to be a dict",
parameter="prompt",
value=type(prompt).__name__,
)
prompt_dict = cast(dict, prompt)
decoder_prompt = prompt.get("decoder_prompt")
if not isinstance(decoder_prompt, str):
raise ValueError(
f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}"
raise VLLMValidationError(
"Expected decoder_prompt to be str",
parameter="decoder_prompt",
value=type(decoder_prompt).__name__,
)
prompt_dict["decoder_prompt"] = decoder_prompt.replace(
"<|notimestamps|>", "<|0.00|>"
@ -412,7 +423,7 @@ class OpenAISpeechToText(OpenAIServing):
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
return self.create_error_response(e)
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try:
@ -448,8 +459,7 @@ class OpenAISpeechToText(OpenAIServing):
for i, prompt in enumerate(prompts)
]
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
if request.stream:
return stream_generator_method(
@ -523,8 +533,7 @@ class OpenAISpeechToText(OpenAIServing):
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
async def _speech_to_text_stream_generator(
self,
@ -634,9 +643,8 @@ class OpenAISpeechToText(OpenAIServing):
)
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in %s stream generator.", self.task_type)
data = self.create_streaming_error_response(str(e))
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"

View File

@ -52,6 +52,7 @@ class ServingScores(OpenAIServing):
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
score_template: str | None = None,
log_error_stack: bool = False,
) -> None:
super().__init__(
@ -60,6 +61,7 @@ class ServingScores(OpenAIServing):
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.score_template = score_template
async def _embedding_score(
self,
@ -169,6 +171,7 @@ class ServingScores(OpenAIServing):
data_2=data_2,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
score_template=self.score_template,
)
self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt)
if request.mm_processor_kwargs is not None:

View File

@ -12,6 +12,7 @@ import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike
@ -162,8 +163,9 @@ class BaseRenderer(ABC):
) -> list[EmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
raise VLLMValidationError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
parameter="prompt_embeds",
)
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
@ -396,10 +398,12 @@ class CompletionRenderer(BaseRenderer):
) -> TokensPrompt:
"""Create validated TokensPrompt."""
if max_length is not None and len(token_ids) > max_length:
raise ValueError(
raise VLLMValidationError(
f"This model's maximum context length is {max_length} tokens. "
f"However, your request has {len(token_ids)} input tokens. "
"Please reduce the length of the input messages."
"Please reduce the length of the input messages.",
parameter="input_tokens",
value=len(token_ids),
)
tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)

View File

@ -11,9 +11,11 @@ from vllm.entrypoints.chat_utils import (
ChatCompletionContentPartImageEmbedsParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
ChatTemplateResolutionError,
MultiModalItemTracker,
_ContentPart,
_parse_chat_message_content_part,
apply_hf_chat_template,
)
from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
@ -139,10 +141,8 @@ def _parse_score_content(
return next(iter(mm_placeholder_storage.values()))[0]
def apply_score_template(
model_config: ModelConfig,
prompt_1: str,
prompt_2: str,
def _apply_model_score_template(
model_config: ModelConfig, prompt_1: str, prompt_2: str
) -> str:
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
from vllm.model_executor.model_loader import get_model_cls
@ -181,6 +181,7 @@ def get_score_prompt(
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
score_template: str | None = None,
) -> tuple[str, TokensPrompt]:
prompt_1, prompt_2, mm_data = parse_score_data(
data_1,
@ -190,19 +191,48 @@ def get_score_prompt(
from vllm.model_executor.model_loader import get_model_cls
model = get_model_cls(model_config)
if supports_score_template(model):
full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
elif model_config.use_pad_token:
# cross_encoder models defaults to using pad_token.
prompt_inputs = tokenizer(
text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
)
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
def default_tokenizer_encode():
if supports_score_template(model):
full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
else:
if model_config.use_pad_token:
# cross_encoder models defaults to using pad_token.
prompt_inputs = tokenizer(
text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
)
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
else:
# `llm as reranker` models defaults to not using pad_token.
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
return full_prompt, prompt_inputs
# FIXME: For now, we only apply a template when one is explicitly provided.
# We cannot rely on the tokenizer's chat template because many models
# inherit junk templates from their base LLM, which breaks both the models
# and the tests that use them.
if score_template is None:
full_prompt, prompt_inputs = default_tokenizer_encode()
else:
# `llm as reranker` models defaults to not using pad_token.
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
# FIXME: Try applying a score template from the CLI arg or tokenizer_config.json
# If that fails because there is no such template,
# fall back to the default implementation.
try:
full_prompt = apply_hf_chat_template(
tokenizer,
[
{"role": "query", "content": prompt_1},
{"role": "document", "content": prompt_2},
],
score_template,
tools=None,
model_config=model_config,
)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
except ChatTemplateResolutionError:
full_prompt, prompt_inputs = default_tokenizer_encode()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

View File

@ -186,6 +186,7 @@ class DPMetadata:
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any]
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
"""
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata
@ -193,7 +194,6 @@ class ForwardContext:
for each microbatch.
Set dynamically for each forward pass
"""
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass

View File

@ -12,7 +12,6 @@ from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import (
get_lora_id,
is_base_embeddding_weights,
is_regex_target_modules,
parse_fine_tuned_lora_name,
)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@ -201,37 +200,13 @@ class LoRAModel:
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
# When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = []
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
# Compatible with more modules,
# such as:layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules and not is_regex_target_modules(
peft_helper.target_modules, expected_lora_modules
):
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
lora_file_path = (
lora_bin_file_path
if os.path.isfile(lora_bin_file_path)
else lora_pt_file_path
)
tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
check_unexpected_modules(tensors)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")

View File

@ -11,9 +11,11 @@ import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform
logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant()
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
@ -150,7 +152,8 @@ def _get_lora_b_ptr(
@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
if user_defined_config_folder is not None:
# Avoid optimizing for the batch invariant case. Use default config
if user_defined_config_folder is not None and not is_batch_invariant:
gpu_name = torch.cuda.get_device_name()
gpu_name = gpu_name.replace(" ", "_")
gpu_name = gpu_name.replace("-", "_")
@ -203,11 +206,14 @@ def get_lora_op_configs(
# default config
default = {}
if op_type == "shrink":
split_k = 64 if batch < 128 else 8
if is_batch_invariant:
split_k = 1
default = {
"block_m": 32,
"block_n": 16,
"block_k": 256 if batch < 128 else 32,
"split_k": 64 if batch < 128 else 8,
"split_k": split_k,
"num_warps": 4,
"num_ctas": 1,
"group_size_m": 8,

View File

@ -5,7 +5,6 @@ import os
from typing import TYPE_CHECKING, Optional
import huggingface_hub
import regex as re
from huggingface_hub.utils import (
EntryNotFoundError,
HfHubHTTPError,
@ -186,39 +185,6 @@ def is_base_embeddding_weights(name: str) -> bool:
return name.endswith(embedding_suffixes)
def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: set[str]
) -> bool:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
determine whether the suffix in the regular expression is present in the
`expected_lora_modules`.
"""
def is_valid_regex(pattern):
try:
re.compile(pattern)
return True
except re.error:
return False
def is_subset(sub_list, full_set):
return set(sub_list).issubset(full_set)
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
if not isinstance(load_modules, str):
return False
if is_valid_regex(load_modules):
match = re.search(r"\((.*?)\)\$?$", load_modules)
if match:
suffix = match.group(1).split("|")
return is_subset(suffix, expected_lora_modules)
return False
def get_supported_lora_modules(model: nn.Module) -> list[str]:
"""
In vLLM, all linear layers support LoRA.

View File

@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
query_start_loc_p = attn_metadata.query_start_loc_p
BCx, _ = self.in_proj(hidden_states)
@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
if has_prefill
else None
)
conv_output_list = []

View File

@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=routing_method_type,
do_finalize=True,
)[0]

View File

@ -88,6 +88,26 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
}
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
from vllm.config.pooler import PoolingTypeStr
hf_config = vllm_config.model_config.hf_config
hf_config.is_causal = False
pooling_type_map: dict[str, PoolingTypeStr] = {
"avg": "MEAN",
"cls": "CLS",
"last": "LAST",
}
pooling_type = pooling_type_map.get(hf_config.pooling, None)
if pooling_type is None:
raise ValueError(f"pool_type {hf_config.pooling} not supported")
vllm_config.model_config.pooler_config.pooling_type = pooling_type
class NomicBertModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@ -509,6 +529,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteNewModel": GteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig,
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
"LlamaBidirectionalModel": LlamaBidirectionalConfig,
"NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,

View File

@ -94,6 +94,12 @@ class SupportsMultiModal(Protocol):
`multimodal_config.mm_encoder_tp_mode="data"`.
"""
requires_raw_input_tokens: ClassVar[bool] = False
"""
A flag that indicates this model processes input id tokens
in their raw form and not input embeddings.
"""
merge_by_field_config: ClassVar[bool | None] = None
"""
[DEPRECATED] A flag that indicates which implementation of
@ -324,6 +330,10 @@ def supports_multimodal_raw_input_only(model: type[object] | object) -> bool:
return getattr(model, "supports_multimodal_raw_input_only", False)
def requires_raw_input_tokens(model: type[object] | object) -> bool:
return getattr(model, "requires_raw_input_tokens", False)
def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
return getattr(model, "supports_encoder_tp_data", False)

View File

@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
@ -167,7 +166,6 @@ class Jais2Attention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style,
@ -304,17 +302,12 @@ class Jais2Model(nn.Module):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
@ -456,29 +449,15 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = self._init_model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
@ -487,7 +466,7 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
config.vocab_size, scale=logit_scale
)
else:
self.lm_head = PPMissingLayer()

View File

@ -57,7 +57,14 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
from .adapters import as_embedding_model, as_seq_cls_model
from .interfaces import (
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .interfaces_base import attn_type
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@ -698,3 +705,17 @@ class LlamaForCausalLM(
name = name.replace(item, mapping[item])
return name, loaded_weight
@attn_type("encoder_only")
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
# This class sets the correct attention type and pooling type
# through LlamaBidirectionalConfig.
pass
@attn_type("encoder_only")
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
# This class sets the correct attention type and pooling type
# through LlamaBidirectionalConfig.
pass

View File

@ -46,6 +46,7 @@ from .interfaces import (
has_noops,
is_attention_free,
is_hybrid,
requires_raw_input_tokens,
supports_cross_encoding,
supports_mamba_prefix_caching,
supports_multimodal,
@ -203,6 +204,7 @@ _EMBEDDING_MODELS = {
"GteNewModel": ("bert_with_rope", "GteNewModel"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
# Multiple models share the same architecture, so we include them all
@ -246,6 +248,11 @@ _CROSS_ENCODER_MODELS = {
"bert_with_rope",
"GteNewForSequenceClassification",
),
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
"LlamaBidirectionalForSequenceClassification": (
"llama",
"LlamaBidirectionalForSequenceClassification",
),
"ModernBertForSequenceClassification": (
"modernbert",
"ModernBertForSequenceClassification",
@ -259,8 +266,6 @@ _CROSS_ENCODER_MODELS = {
"roberta",
"RobertaForSequenceClassification",
),
# [Auto-converted (see adapters.py)]
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
}
_MULTIMODAL_MODELS = {
@ -418,6 +423,7 @@ _MULTIMODAL_MODELS = {
),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
"VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"), # noqa: E501
# [Encoder-decoder]
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
@ -535,6 +541,7 @@ class _ModelInfo:
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input_only: bool
requires_raw_input_tokens: bool
supports_multimodal_encoder_tp_data: bool
supports_pp: bool
has_inner_state: bool
@ -558,6 +565,7 @@ class _ModelInfo:
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
model
),
requires_raw_input_tokens=requires_raw_input_tokens(model),
supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
model
),

View File

@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Literal
import torch
from torch import nn
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
@ -32,6 +31,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from vllm.transformers_utils.config import is_rope_parameters_nested
if TYPE_CHECKING:
from vllm.config import VllmConfig
@ -207,7 +207,7 @@ def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
rope_parameters: dict | None = getattr(text_config, "rope_parameters", None) or {}
if rope_parameters:
# Nest rope_parameters if not nested already to simplify logic
if not set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
if not is_rope_parameters_nested(rope_parameters):
rope_parameters = {"": rope_parameters}
return all(rp["rope_type"] != "dynamic" for rp in rope_parameters.values())
return True

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
@ -116,10 +117,7 @@ class VoxtralProcessorAdapter:
self,
audio_length: int,
) -> int:
pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
audio_length, self.sampling_rate
)
return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
def __call__(
self,
@ -158,7 +156,14 @@ class VoxtralProcessorAdapter:
assert audio.ndim == 1
# pad if necessary
audio = self._audio_processor.pad(audio, self.sampling_rate)
# TODO(Patrick) - remove once mistral-common is bumped
sig = inspect.signature(self._audio_processor.pad)
if "is_online_streaming" in sig.parameters:
audio = self._audio_processor.pad(
audio, self.sampling_rate, is_online_streaming=False
)
else:
audio = self._audio_processor.pad(audio, self.sampling_rate)
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id
@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
remapping_rules = [
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
(r"mm_whisper_embeddings\.(.*)", r"\1"),
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
(
@ -535,13 +541,16 @@ class VoxtralForConditionalGeneration(
def llm_weights_generator():
nonlocal loaded_weights
for name, w in weights:
is_encoder = (
name.startswith("mm_whisper_embeddings")
and not name.startswith("mm_whisper_embeddings.tok_embeddings")
and not name.startswith(
"mm_whisper_embeddings.audio_language_projection"
is_encoder = False
for k in [
"mm_whisper_embeddings",
"mm_streams_embeddings.embedding_module",
]:
is_encoder |= (
name.startswith(k)
and not name.startswith(f"{k}.tok_embeddings")
and not name.startswith(f"{k}.audio_language_projection")
)
)
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
mistral_remapping = [
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
(
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
r"whisper_encoder.conv1.\1",
@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module):
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
r"whisper_encoder.conv2.\1",
),
(
r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
r"whisper_encoder.conv1.\1",
), # noqa: E501
(
r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
r"whisper_encoder.conv2.\1",
), # noqa: E501
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",

View File

@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Mapping
import torch
from vllm.config.vllm import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.voxtral import (
VoxtralDummyInputsBuilder,
VoxtralForConditionalGeneration,
VoxtralMultiModalProcessor,
VoxtralProcessingInfo,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalKwargsOptionalItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .utils import (
_flatten_embeddings,
)
logger = init_logger(__name__)
class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
def __init__(
self,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: BaseMultiModalProcessorCache | None = None,
) -> None:
# streaming can't make use of a cache yet
super().__init__(info, dummy_inputs, cache=None)
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsOptionalItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
# there are no placeholder audio tokens for streaming
# so we need to build the place placeholder positions manually
# in streaming there is always only one audio input
audios = mm_kwargs.get("audio", [])
assert len(audios) == 1, (
f"Expected only one audio input for streaming, got {mm_kwargs=}"
)
tokenizer = self.info.get_tokenizer()
audio_config = tokenizer.instruct.audio_encoder.audio_config
num_audio_samples = audios[0]["audio_arrays"].data.shape[0]
length = audio_config.num_audio_tokens(num_audio_samples)
features_info = PlaceholderFeaturesInfo(
modality="audio",
item_idx=0,
start_idx=0,
tokens=length
* [0], # only used for length computation, so we can take dummy inputs
is_embed=None,
)
return prompt_ids, {"audio": [features_info]}
class TimeEmbedding(torch.nn.Module):
"""Sinusoidal Embedding for encoding time"""
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = torch.exp(
-math.log(self.theta)
* torch.arange(self.dim // 2).float()
/ (self.dim // 2)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, t: torch.Tensor) -> torch.Tensor:
t = t[..., None] # (B,) -> (B, 1) or (B, T) -> (B, T, 1)
inv_freq = self.inv_freq.to(device=t.device, dtype=t.dtype)
emb = (
t * inv_freq
) # (B, 1) x (D/2,) -> (B, D/2) or (B, T, 1) x (D/2,) -> (B, T, D/2)
return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D)
@MULTIMODAL_REGISTRY.register_processor(
VoxtralStreamingMultiModalProcessor,
info=VoxtralProcessingInfo,
dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
requires_raw_input_tokens = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.time_embedding: TimeEmbedding = TimeEmbedding(
dim=self.config.text_config.hidden_size
)
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
_n_delay_tokens = (
audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
)
assert _n_delay_tokens.is_integer(), (
f"n_delay_tokens must be integer, got {_n_delay_tokens}"
)
self.n_delay_tokens = int(_n_delay_tokens)
@property
def audio_config(self):
return self.tokenizer.instruct.audio_encoder.audio_config
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Pass post-conv embeddings directly as input"""
# for streaming we simply flatten the multimodal embeddings
# to be in tensor format, we treat the input ids later
assert multimodal_embeddings is not None
assert len(multimodal_embeddings) > 0, (
"For streaming you must provide a multimodal_embedding at every step."
)
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
return mm_embeds_flat
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
assert inputs_embeds is not None
assert input_ids is not None
pool_size = self.config.audio_config.block_pool_size
inputs_embeds = inputs_embeds.view(
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
)
audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers(
inputs_embeds
)
num_tokens, audio_hidden_size = audio_hidden_states.shape
assert num_tokens % self.downsample_factor == 0
audio_hidden_states = audio_hidden_states.reshape(
num_tokens // self.downsample_factor,
audio_hidden_size * self.downsample_factor,
)
audio_text_embeds = self.audio_language_adapter(audio_hidden_states)
text_embeds = self.language_model.embed_input_ids(input_ids)
# sum pool text and audio embeddings
inputs_embeds = audio_text_embeds + text_embeds
time_tensor = torch.tensor(
[self.n_delay_tokens],
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
inputs_embeds = inputs_embeds + self.time_embedding(time_tensor)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def embed_multimodal(
self, **kwargs
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
"""Transform audio waveforms -> initial whisper post-conv embeddings"""
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
assert audio_inputs is not None, (
"For streaming you must provide an audio input at every step."
)
multiple_of = self.audio_config.raw_audio_length_per_tok
assert all(
(this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
), (
f"Every input audio waveform has to be a multiple of {multiple_of}, but"
f" one is {this_audio} with {(this_audio / multiple_of)=}."
)
mel_features = [
self.whisper_encoder.compute_whisper_melspec(audio).to(
self.whisper_encoder.dtype
)
for audio in audio_inputs
]
seq_lens = [mel.shape[1] for mel in mel_features]
# [total_num_20ms_frames, hidden_size]
audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
mel_features
)[0]
conv_stride = self.whisper_encoder.whisper_encoder.total_stride
audio_embeddings_per_sample = audio_embeddings.split(
[s // conv_stride for s in seq_lens], dim=0
)
# audio_embeddings per sample need to be divisible by 4
pool_size = self.config.audio_config.block_pool_size
assert all(
(this_shape := sample.shape[0]) % pool_size == 0
for sample in audio_embeddings_per_sample
), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
audio_embeddings_per_sample = [
e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
for e in audio_embeddings_per_sample
]
return audio_embeddings_per_sample

View File

@ -1,9 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from functools import partial
from typing import Annotated, Literal, cast
import numpy as np
@ -16,7 +18,10 @@ from transformers import (
)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.layer import Attention, AttentionType
from vllm.attention.backends.abstract import (
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
@ -34,6 +39,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.whisper_utils import (
ISO639_1_SUPPORTED_LANGS,
WhisperAttentionWithBlockPooling,
WhisperCausalConv1d,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
@ -64,67 +74,11 @@ from .utils import (
logger = init_logger(__name__)
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
"af": "Afrikaans",
"ar": "Arabic",
"hy": "Armenian",
"az": "Azerbaijani",
"be": "Belarusian",
"bs": "Bosnian",
"bg": "Bulgarian",
"ca": "Catalan",
"zh": "Chinese",
"hr": "Croatian",
"cs": "Czech",
"da": "Danish",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"gl": "Galician",
"de": "German",
"el": "Greek",
"he": "Hebrew",
"hi": "Hindi",
"hu": "Hungarian",
"is": "Icelandic",
"id": "Indonesian",
"it": "Italian",
"ja": "Japanese",
"kn": "Kannada",
"kk": "Kazakh",
"ko": "Korean",
"lv": "Latvian",
"lt": "Lithuanian",
"mk": "Macedonian",
"ms": "Malay",
"mr": "Marathi",
"mi": "Maori",
"ne": "Nepali",
"no": "Norwegian",
"fa": "Persian",
"pl": "Polish",
"pt": "Portuguese",
"ro": "Romanian",
"ru": "Russian",
"sr": "Serbian",
"sk": "Slovak",
"sl": "Slovenian",
"es": "Spanish",
"sw": "Swahili",
"sv": "Swedish",
"tl": "Tagalog",
"ta": "Tamil",
"th": "Thai",
"tr": "Turkish",
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh",
}
class WhisperPosEmbedType(enum.Enum):
SINUSOIDAL = "sinusoidal"
NOPE = "nope"
LEARNED = "learned"
class WhisperAudioInputs(TensorSchema):
@ -184,6 +138,8 @@ class WhisperAttention(nn.Module):
num_heads: int,
bias: bool = True,
attn_type: AttentionType = AttentionType.DECODER,
per_layer_sliding_window: int | None = None,
block_pool_size: int = 1,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
@ -242,7 +198,14 @@ class WhisperAttention(nn.Module):
attn_type=self.attn_type,
)
else: # AttentionType.DECODER (regular decoder self-attention)
self.attn = Attention(
if block_pool_size > 1:
attn_cls = partial(
WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
)
else:
attn_cls = Attention
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,
@ -251,6 +214,7 @@ class WhisperAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
per_layer_sliding_window=per_layer_sliding_window,
)
def _init_qkv(
@ -386,6 +350,9 @@ class WhisperEncoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
is_causal = getattr(config, "is_causal", False)
sliding_window = getattr(config, "sliding_window", None)
block_pool_size = getattr(config, "block_pool_size", 1)
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
@ -393,7 +360,9 @@ class WhisperEncoderLayer(nn.Module):
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
attn_type=AttentionType.ENCODER,
attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
block_pool_size=block_pool_size,
per_layer_sliding_window=sliding_window,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
@ -492,12 +461,21 @@ class WhisperEncoder(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.pos_embed_type = WhisperPosEmbedType(
getattr(config, "pos_embed", "sinusoidal")
)
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
is_causal = getattr(config, "is_causal", False)
Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(
@ -507,29 +485,54 @@ class WhisperEncoder(nn.Module):
)
self.layer_norm = nn.LayerNorm(config.d_model)
maybe_fp32_init_ctx = (
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
)
with (
torch.no_grad(),
maybe_fp32_init_ctx,
if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
raise ValueError(
"Only NOPE position embeddings are supported "
f"for causal models, but got {self.pos_embed_type}"
)
elif self.pos_embed_type in (
WhisperPosEmbedType.SINUSOIDAL,
WhisperPosEmbedType.LEARNED,
):
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)
maybe_fp32_init_ctx = (
set_default_torch_dtype(torch.float32)
if init_in_fp32
else nullcontext()
)
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
with (
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(
self.max_source_positions, embed_dim
)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)
)
def forward_conv(
self, input_features: torch.Tensor | list[torch.Tensor]
) -> torch.Tensor:
hidden_states = []
input_is_batched = False
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.transpose(-1, -2)
embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
embeds.dtype
)
if self.pos_embed_type in (
WhisperPosEmbedType.SINUSOIDAL,
WhisperPosEmbedType.LEARNED,
):
embeds = embeds.transpose(-1, -2)
embeds = (
embeds + self.embed_positions.weight[: embeds.size(-2), :]
).to(embeds.dtype)
elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
embeds = embeds.transpose(-1, -2).to(embeds.dtype)
else:
raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
hidden_states.append(embeds)
input_is_batched = embeds.ndim > 2
# Input to MHA must be B x T x D
@ -539,12 +542,19 @@ class WhisperEncoder(nn.Module):
else:
hidden_states = torch.stack(hidden_states, dim=0)
return hidden_states
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
hidden_states = self.forward_conv(input_features)
return self.forward_layers(hidden_states)
class WhisperDecoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import functools
import math
from dataclasses import replace
import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend_with_overrides,
)
from vllm.v1.kv_cache_interface import AttentionSpec
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
"af": "Afrikaans",
"ar": "Arabic",
"hy": "Armenian",
"az": "Azerbaijani",
"be": "Belarusian",
"bs": "Bosnian",
"bg": "Bulgarian",
"ca": "Catalan",
"zh": "Chinese",
"hr": "Croatian",
"cs": "Czech",
"da": "Danish",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"gl": "Galician",
"de": "German",
"el": "Greek",
"he": "Hebrew",
"hi": "Hindi",
"hu": "Hungarian",
"is": "Icelandic",
"id": "Indonesian",
"it": "Italian",
"ja": "Japanese",
"kn": "Kannada",
"kk": "Kazakh",
"ko": "Korean",
"lv": "Latvian",
"lt": "Lithuanian",
"mk": "Macedonian",
"ms": "Malay",
"mr": "Marathi",
"mi": "Maori",
"ne": "Nepali",
"no": "Norwegian",
"fa": "Persian",
"pl": "Polish",
"pt": "Portuguese",
"ro": "Romanian",
"ru": "Russian",
"sr": "Serbian",
"sk": "Slovak",
"sl": "Slovenian",
"es": "Spanish",
"sw": "Swahili",
"sv": "Swedish",
"tl": "Tagalog",
"ta": "Tamil",
"th": "Thai",
"tr": "Turkish",
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh",
}
def _pad1d(
x: torch.Tensor,
paddings: tuple[int, int],
mode: str = "constant",
value: float = 0.0,
) -> torch.Tensor:
"""Tiny wrapper around F.pad, just to allow for
reflect padding on small input.
If this is the case, we insert extra 0 padding
to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
class WhisperCausalConv1d(nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = True,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
self._stride = self.stride[0]
self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1
self._padding_total = self._effective_kernel_size - self._stride
def forward(self, x: torch.Tensor) -> torch.Tensor:
n_frames = (
x.shape[-1] - self._effective_kernel_size + self._padding_total
) / self._stride + 1
target_length = (math.ceil(n_frames) - 1) * self._stride + (
self._effective_kernel_size - self._padding_total
)
extra_padding = target_length - x.shape[-1]
x = _pad1d(x, (self._padding_total, extra_padding), mode="constant")
return super().forward(x)
@functools.lru_cache
def create_whisper_attention_backend_with_block_pooling(
underlying_attn_backend: AttentionBackend, block_pool_size: int
) -> type[AttentionBackend]:
prefix = "WhisperAttentionWithBlockPooling_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class WhisperAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert kv_cache_spec.num_kv_heads % block_pool_size == 0
kv_cache_spec = replace(
kv_cache_spec,
block_size=kv_cache_spec.block_size * block_pool_size,
num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
)
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
new_common_attn_metadata.query_start_loc *= block_pool_size
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
new_common_attn_metadata.seq_lens *= block_pool_size
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
new_common_attn_metadata.num_actual_tokens *= block_pool_size
new_common_attn_metadata.max_query_len *= block_pool_size
new_common_attn_metadata.max_seq_len *= block_pool_size
original_slot_mapping = common_attn_metadata.slot_mapping
common_prefix_len *= block_pool_size
new_common_attn_metadata.slot_mapping = (
(
original_slot_mapping.unsqueeze(1) * block_pool_size
+ torch.arange(block_pool_size, device=original_slot_mapping.device)
)
.flatten()
.clamp(min=-1)
)
return super().build(
common_prefix_len, new_common_attn_metadata, fast_build
)
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
raise NotImplementedError(
f"{underlying_attn_backend} is not yet supported."
"Contributions to support more backends are much "
"appreciated."
)
attn_backend = subclass_attention_backend_with_overrides(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
overrides={
"get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder,
"get_kv_cache_shape": lambda num_blocks,
block_size,
num_kv_heads,
head_size,
cache_dtype_str: (
2,
num_blocks,
# we stretch each block by `block_pool_size`
block_size * block_pool_size,
num_kv_heads // block_pool_size,
head_size,
), # TODO: generalize to other backends
},
)
return attn_backend
class WhisperAttentionWithBlockPooling(Attention):
"""Attention layer with block pooling."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
logits_soft_cap: float | None = None,
per_layer_sliding_window: int | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
block_pool_size: int = 1,
attn_backend: type[AttentionBackend] | None = None,
**extra_impl_args,
) -> None:
self.block_pool_size = block_pool_size
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=attn_type,
)
attn_backend = create_whisper_attention_backend_with_block_pooling(
underlying_attn_backend, block_pool_size
)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=logits_soft_cap,
per_layer_sliding_window=per_layer_sliding_window,
prefix=prefix,
attn_type=attn_type,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend,
**extra_impl_args,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig):
kv_cache_spec = super().get_kv_cache_spec(vllm_config)
assert isinstance(kv_cache_spec, AttentionSpec)
kv_cache_spec = replace(
kv_cache_spec,
num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads,
)
return kv_cache_spec

View File

@ -111,11 +111,16 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
return librosa.load(filepath, sr=None)
def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
def encode_base64(
self,
media: tuple[npt.NDArray, int],
*,
audio_format: str = "WAV",
) -> str:
audio, sr = media
with BytesIO() as buffer:
soundfile.write(buffer, audio, sr, format="WAV")
soundfile.write(buffer, audio, sr, format=audio_format)
data = buffer.getvalue()
return base64.b64encode(data).decode("utf-8")

View File

@ -8,8 +8,12 @@ import pybase64
import torch
from PIL import Image
from vllm.logger import init_logger
from .base import MediaIO, MediaWithBytes
logger = init_logger(__file__)
def rescale_image_size(
image: Image.Image, size_factor: float, transpose: int = -1
@ -104,8 +108,17 @@ class ImageMediaIO(MediaIO[Image.Image]):
self,
media: Image.Image,
*,
image_format: str = "JPEG",
image_format: str | None = None,
) -> str:
if image_format is None:
logger.warning_once(
"The default format of `ImageMediaIO.encode_base64` will be changed "
'from "JPEG" to "PNG" in v0.15 to avoid lossy compression. '
"To continue using the old default, "
'pass `format="JPEG"` explicitly to silence this warning.'
)
image_format = "JPEG"
image = media
with BytesIO() as buffer:

View File

@ -3,6 +3,7 @@
import asyncio
import atexit
import mimetypes
from collections.abc import Generator, Set
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
@ -357,17 +358,31 @@ class MediaConnector:
def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
*,
format: str = "WAV",
) -> str:
"""Encode audio as base64."""
audio_io = AudioMediaIO()
return audio_io.encode_base64((audio, sampling_rate))
return audio_io.encode_base64((audio, sampling_rate), audio_format=format)
def encode_audio_url(
audio: np.ndarray,
sampling_rate: int,
*,
format: str = "WAV",
) -> str:
"""Encode audio as a data URL."""
audio_b64 = encode_audio_base64(audio, sampling_rate, format=format)
mimetype = mimetypes.types_map.get("." + format.lower(), "audio")
return f"data:{mimetype};base64,{audio_b64}"
def encode_image_base64(
image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "JPEG",
format: str | None = None,
) -> str:
"""
Encode a pillow image to base64 format.
@ -378,10 +393,45 @@ def encode_image_base64(
return image_io.encode_base64(image, image_format=format)
def encode_video_base64(frames: npt.NDArray) -> str:
def encode_image_url(
image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "PNG",
) -> str:
"""
Encode a pillow image as a data URL.
By default, the image is converted into RGB format before being encoded.
"""
image_b64 = encode_image_base64(image, image_mode=image_mode, format=format)
mimetype = mimetypes.types_map.get("." + format.lower(), "image")
return f"data:{mimetype};base64,{image_b64}"
def encode_video_base64(
frames: npt.NDArray,
*,
format: str = "JPEG",
) -> str:
image_io = ImageMediaIO()
video_io = VideoMediaIO(image_io)
return video_io.encode_base64(frames)
return video_io.encode_base64(frames, video_format=format)
def encode_video_url(
frames: npt.NDArray,
*,
format: str = "JPEG",
) -> str:
video_b64 = encode_video_base64(frames, format=format)
if format.lower() == "jpeg":
mimetype = "video/jpeg"
else:
mimetype = mimetypes.types_map.get("." + format.lower(), "video")
return f"data:{mimetype};base64,{video_b64}"
def argsort_mm_positions(

View File

@ -156,7 +156,9 @@ class XPUPlatform(Platform):
if vllm_config.lora_config is not None:
compilation_config.mode = CompilationMode.NONE
# decrease triton kernel compilation scratch space for speculative decoding
if vllm_config.speculative_config is not None:
os.environ["IGC_ForceOCLSIMDWidth"] = "16" # noqa: SIM112
# check and update parallel config
parallel_config = vllm_config.parallel_config
# Only override worker_cls if it's still the default "auto"

View File

@ -131,78 +131,105 @@ class MistralToolParser(ToolParser):
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes!
Extract the tool calls from a complete model response.
Content and tool calls formatting depends on the Mistral's tokenizer version
used to train the model:
- < v11: `content[BOT] [{tool_call1},{tool_call2}]`
- >= v11: `content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}`
with [BOT] the tool call token.
Note:
For tokenizer versions >= v11, tool calls with arguments wrongly formatted
are still returned as tool calls. This is to allow the model to know it
tried to make a tool call. It reduces chance of another failure and
prevents that the context is filled with tool calls wrongly placed in
assistant message contents.
"""
# case -- if a tool call token is not present, return a text response
# If the tool call token is not present, return a text response
if self.bot_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()
content_and_raw_tool_calls = model_output.split(self.bot_token)
content = content_and_raw_tool_calls[0]
raw_tool_calls = content_and_raw_tool_calls[1:]
try:
# >= v11: content[BOT]tool_name1{args_call1}[BOT]tool_name2{args_call2}
if not self._is_pre_v11:
tool_calls = []
for raw_tool_call in raw_tool_calls:
if "{" not in raw_tool_call:
continue
end_name = raw_tool_call.find("{")
tool_name, args = (
raw_tool_call[:end_name],
raw_tool_call[end_name:],
)
tool_calls.append({"name": tool_name, "arguments": args})
# < v11: content[BOT] [{tool_call1},{tool_call2}]
else:
if len(raw_tool_calls) != 1:
raise ValueError(
"Only one BOT token should have been outputted, "
f"but got {model_output}."
)
stringified_tool_calls = raw_tool_calls[0].strip()
try:
if not self._is_pre_v11:
function_call_arr = []
for single_tool_content in model_output.split(self.bot_token):
if "{" not in single_tool_content:
continue
end_name = single_tool_content.find("{")
fn_name, args = (
single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
tool_calls = json.loads(stringified_tool_calls)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)
# Tool Call
tool_calls: list[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"], ensure_ascii=False
try:
raw_tool_call = self.tool_call_regex.findall(
stringified_tool_calls
)[0]
tool_calls = json.loads(raw_tool_call)
except (IndexError, json.JSONDecodeError):
logger.exception("Error in extracting tool call from response: {e}")
# If raw decoding and decoding post regex rule fails, then just
# return content.
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=stringified_tool_calls,
)
else:
tool_calls = [
{
"name": tool_call["name"],
"arguments": json.dumps(
tool_call["arguments"], ensure_ascii=False
),
),
)
for raw_function_call in function_call_arr
]
}
for tool_call in tool_calls
]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if len(content) > 0 else None,
mistral_tool_calls: list[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=tool_call["name"],
arguments=tool_call["arguments"],
),
)
for tool_call in tool_calls
]
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=tool_content
)
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=mistral_tool_calls,
content=content if len(content) > 0 else None,
)
def extract_tool_calls_streaming(
self,

View File

@ -15,7 +15,6 @@ from huggingface_hub import (
)
from packaging.version import Version
from transformers import GenerationConfig, PretrainedConfig
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
from transformers.models.auto.image_processing_auto import get_image_processor_config
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
@ -44,6 +43,16 @@ from .repo_utils import (
with_retry,
)
try:
# Transformers v5
from transformers.configuration_utils import ALLOWED_ATTENTION_LAYER_TYPES
except ImportError:
# Transformers v4
from transformers.configuration_utils import (
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
)
if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
else:
@ -104,6 +113,14 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
}
def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
"""Check if rope_parameters is nested by layer types."""
# Cannot be nested if rope_parameters is empty
if not rope_parameters:
return False
return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES)
class HFConfigParser(ConfigParserBase):
def parse(
self,
@ -313,19 +330,25 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
rope_theta = getattr_iter(config, names, None, warn=True)
names = ["partial_rotary_factor", "rotary_pct", "rotary_emb_fraction"]
partial_rotary_factor = getattr_iter(config, names, None, warn=True)
ompe = getattr(config, "original_max_position_embeddings", None)
if Version(version("transformers")) < Version("5.0.0.dev0"):
# Transformers v4 installed, legacy config fields may be present
if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
config.rope_parameters = rope_scaling
if (
rope_theta is not None or partial_rotary_factor is not None
rope_theta is not None
or partial_rotary_factor is not None
or ompe is not None
) and not getattr(config, "rope_parameters", None):
config.rope_parameters = {"rope_type": "default"}
# Patch legacy fields into rope_parameters
if rope_theta is not None:
config.rope_parameters["rope_theta"] = rope_theta
if partial_rotary_factor is not None:
config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
if ompe is not None:
config.rope_parameters["original_max_position_embeddings"] = ompe
elif rope_theta is not None or getattr(config, "rope_parameters", None):
# Transformers v5 installed
# Patch these fields in case they used non-standard names
@ -341,12 +364,8 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
if getattr(config, "rope_parameters", None) is None:
return
# Add original_max_position_embeddings if present
if ompe := getattr(config, "original_max_position_embeddings", None):
config.rope_parameters["original_max_position_embeddings"] = ompe
# Handle nested rope_parameters in interleaved sliding attention models
if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
if is_rope_parameters_nested(config.rope_parameters):
for rope_parameters_layer_type in config.rope_parameters.values():
patch_rope_parameters_dict(rope_parameters_layer_type)
else:

View File

@ -184,18 +184,42 @@ def _remap_mistral_audio_args(config: dict) -> dict:
whisper_args = config["multimodal"].pop("whisper_model_args")
encoder_args = whisper_args["encoder_args"]
downsample_args = whisper_args["downsample_args"]
downsample_factor = downsample_args["downsample_factor"]
# make sure that k/v blocks can be allocated with
# unified k/v cache class and pool whisper k/v cache blocks
# with downsample_factor:1 ratio
if encoder_args.get("causal"):
block_pool_size = downsample_factor
config["projection_size"] = downsample_factor * encoder_args["dim"]
else:
block_pool_size = 1
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
if _maybe_sliding_window is None:
sliding_window = None
elif _maybe_sliding_window.isdigit():
sliding_window = int(_maybe_sliding_window)
else:
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
architecture = (
"VoxtralStreamingGeneration"
if encoder_args.get("causal")
else "VoxtralForConditionalGeneration"
)
quant_config = config.get("quantization_config")
config = {
"model_type": "whixtral",
"architectures": ["VoxtralForConditionalGeneration"],
"model_type": "voxtral",
"architectures": [architecture],
"text_config": PretrainedConfig.from_dict(config),
"audio_config": WhisperConfig(
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
window_size=encoder_args["audio_encoding_args"]["window_size"],
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
downsample_factor=downsample_args["downsample_factor"],
downsample_factor=downsample_factor,
d_model=encoder_args["dim"],
encoder_layers=encoder_args["n_layers"],
encoder_ffn_dim=encoder_args["hidden_dim"],
@ -203,6 +227,10 @@ def _remap_mistral_audio_args(config: dict) -> dict:
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
is_causal=encoder_args.get("causal", False),
sliding_window=sliding_window,
block_pool_size=block_pool_size,
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
),
}
if quant_config:

View File

@ -3,17 +3,11 @@
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class Mamba1AttentionBackend(AttentionBackend):
@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
@dataclass
class Mamba1AttentionMetadata:
query_start_loc_p: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
pass
class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba1AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
query_start_loc_p = None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
# We should consolidate this code
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return Mamba1AttentionMetadata(
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)
metadata_cls = Mamba1AttentionMetadata
supports_update_block_table: bool = False

View File

@ -1,19 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools
from dataclasses import dataclass
from dataclasses import dataclass, replace
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
@dataclass
class Mamba2AttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc_p: torch.Tensor
seq_lens: torch.Tensor
prep_initial_states: bool
chunk_size: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p: torch.Tensor | None
seq_idx_p: torch.Tensor | None
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
prep_initial_states: bool = False
chunk_size: int = 0
# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offests into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None
state_indices_tensor: torch.Tensor # shape: [batch,]
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
last_chunk_indices_p: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
supports_update_block_table: bool = True
metadata_cls = Mamba2AttentionMetadata
def __init__(
self,
@ -150,87 +128,93 @@ class Mamba2AttentionMetadataBuilder(
"chunk_size needs to be set in the model config for Mamba2 models"
)
def _compute_chunk_metadata(
self,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba2.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba2AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
seq_lens = common_attn_metadata.seq_lens
common = self._compute_common_metadata(common_attn_metadata)
query_start_loc_p = None
seq_idx_p = None
cu_chunk_seqlen_p = None
last_chunk_indices_p = None
# Need flags to indicate if there are initial states
has_initial_states_p = None
prep_initial_states = False
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# Additional cache-related varaiables:
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
# Compute seq_idx for prefill only
if num_prefills > 0:
# [batch,]
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
if common.num_prefills > 0:
prep_initial_states = (
torch.any(common.has_initial_states_p).item()
if common.has_initial_states_p is not None
else False
)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
@ -239,137 +223,33 @@ class Mamba2AttentionMetadataBuilder(
- num_decode_tokens
)
# The code below carefully constructs the chunks such that:
# 1. Chunks contain tokens from a *single* sequence only.
# 2. For every sequence, we are guaranteed that we can
# retrieve the mamba state *every* chunk_size tokens.
# Constraint (1) dramatically simplifies the mamba2 kernels.
# Constraint (2) dramatically simplifies the implementation
# of prefix caching for mamba2 (wip). We need to take care
# of the interaction with chunked prefill in order to
# satisfy constraint (2).
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
seq_idx_p = torch.as_tensor(
seq_idx, device=query_start_loc_p.device, dtype=torch.int32
seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
seq_lens=seq_lens,
return replace(
common,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,
has_initial_states_p=has_initial_states_p,
seq_idx_p=seq_idx_p,
state_indices_tensor=state_indices_tensor,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)
return attn_metadata
def update_block_table(
self,
metadata: Mamba2AttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> Mamba2AttentionMetadata:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import copy
from dataclasses import dataclass
from typing import ClassVar, TypeVar
import torch
@ -9,20 +11,52 @@ import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
M = TypeVar("M")
M = TypeVar("M", bound="BaseMambaAttentionMetadata")
@dataclass
class BaseMambaAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_reqs: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p: torch.Tensor | None
query_start_loc_p: torch.Tensor | None
num_computed_tokens_p: torch.Tensor | None
state_indices_tensor: torch.Tensor
# The following tensors are only used for prefix caching and are None if disabled
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
metadata_cls: type[M]
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
supports_update_block_table: bool = True
def __init__(
self,
@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
return self.build(0, m)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> M:
"""
Default build implementation for Mamba-like attention backends.
Subclasses (e.g., Mamba2) can override to add additional metadata.
"""
return self._compute_common_metadata(common_attn_metadata)
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,
@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
)
def _compute_common_metadata(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> M:
"""
Compute metadata common to both Mamba1 and Mamba2.
"""
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
# Need flags to indicate if there are initial states
has_initial_states_p = None
query_start_loc_p = None
num_computed_tokens = None
num_computed_tokens_p = None
# for prefix caching
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
block_idx_last_computed_token = None
block_idx_last_scheduled_token = None
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return self.metadata_cls(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
num_reqs=num_reqs,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata

View File

@ -541,11 +541,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
metadata_cls if metadata_cls is not None else MLACommonMetadata
)
self.kv_cache_spec = kv_cache_spec
self.q_data_type = (
current_platform.fp8_dtype()
if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
else vllm_config.model_config.dtype
)
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
@ -689,6 +684,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# For main run, qo_indptr == kv_indptr
kv_indptr = qo_indptr.clone()
# Prepare main prefill
self._fi_prefill_main.plan(
qo_indptr=qo_indptr,
@ -701,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type,
q_data_type=self.model_config.dtype,
)
# Prepare context prefills
@ -720,7 +716,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.q_data_type,
q_data_type=self.model_config.dtype,
)
prefill.prefill_main = self._fi_prefill_main
@ -973,7 +969,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
q_data_type=self.q_data_type,
)
if self._use_cudnn_prefill:
@ -1384,15 +1379,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out
def _run_prefill_new_tokens_fa(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@ -1407,23 +1395,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_new_tokens_fi(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None
if fp8_attention:
logger.debug_once("Running Flashinfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = prefill.prefill_main.run(
q=q,
k=k,
@ -1436,18 +1412,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret
def _run_prefill_new_tokens_cudnn(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running Cudnn prefill new tokens", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
output, lse = cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@ -1469,19 +1437,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return output
def _run_prefill_context_chunk_fa(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
assert prefill.chunked_context is not None
assert fp8_attention is False, (
"FlashAttention prefill does not support fp8 attention"
)
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
@ -1496,22 +1454,10 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_context_chunk_fi(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
assert isinstance(prefill, FlashInferPrefillMetadata)
if fp8_attention:
logger.debug_once("Running FlashInfer prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q,
k=k,
@ -1523,20 +1469,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running Cudnn prefill context chunk", scope="local")
assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None
assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
return cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
@ -1556,28 +1494,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def _run_prefill_new_tokens_trtllm_ragged(
self,
prefill: MLACommonPrefillMetadata,
q,
k,
v,
return_softmax_lse,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
):
logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
"""TRT-LLM ragged attention for new tokens (causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
ret = trtllm_ragged_attention_deepseek(
query=q,
key=k,
@ -1604,15 +1528,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return ret
def _run_prefill_context_chunk_trtllm_ragged(
self,
prefill: MLACommonPrefillMetadata,
chunk_idx: int,
q,
k,
v,
fp8_attention: bool,
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
):
logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
"""TRT-LLM ragged attention for context chunks (non-causal)."""
from flashinfer.prefill import trtllm_ragged_attention_deepseek
@ -1629,13 +1546,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
prefill.workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
fp8_dtype = current_platform.fp8_dtype()
q = q.to(fp8_dtype)
k = k.to(fp8_dtype)
v = v.to(fp8_dtype)
attn_out, lse = trtllm_ragged_attention_deepseek(
query=q,
key=k,
@ -1788,7 +1698,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
fp8_attention: bool,
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
@ -1827,7 +1736,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)
if output is None:
@ -1856,7 +1764,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
dcp_world_size: int,
fp8_attention: bool,
):
assert k_scale is None, "DCP not support scaled kvcache now."
assert attn_metadata.prefill is not None
@ -1933,7 +1840,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
fp8_attention=fp8_attention,
)
if output is None:
@ -1964,7 +1870,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
output: torch.Tensor,
fp8_attention: bool = False,
) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
@ -1984,7 +1889,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k=k,
v=v,
return_softmax_lse=has_context,
fp8_attention=fp8_attention,
)
if has_context:
@ -1997,12 +1901,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
k_scale=None,
dcp_world_size=self.dcp_world_size,
fp8_attention=fp8_attention,
)
)
else:
context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
)
# unpad if necessary
@ -2123,7 +2026,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
attn_metadata,
layer._k_scale,
output=output[num_decode_tokens:],
fp8_attention=fp8_attention,
)
if has_decode:

View File

@ -2,15 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
@dataclass
class ShortConvAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
# For causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
pass
class ShortConvAttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
):
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> ShortConvAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
if num_prefills > 0:
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
attn_metadata = ShortConvAttentionMetadata(
query_start_loc=query_start_loc,
state_indices_tensor=state_indices_tensor,
has_initial_states_p=has_initial_states_p,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
metadata_cls = ShortConvAttentionMetadata

View File

@ -835,6 +835,15 @@ def subclass_attention_backend(
)
def subclass_attention_backend_with_overrides(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],
overrides: dict[str, Any],
) -> type[AttentionBackend]:
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(name, (attention_backend_cls,), overrides)
def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,

View File

@ -75,6 +75,12 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None
# The user-provided request ID. This field is set internally,
# copied from the provided request_id that's originally assigned
# to the request_id field, see InputProcessor.assign_request_id().
# Used in outputs and to support abort(req_id, internal=False).
external_req_id: str | None = None
@property
def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling)."""

View File

@ -290,12 +290,15 @@ class AsyncLLM(EngineClient):
is_pooling = isinstance(params, PoolingParams)
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest):
request = prompt
if request_id != request.request_id:
logger.warning_once(
"AsyncLLM.add_request() was passed a request_id parameter that "
"does not match the EngineCoreRequest.request_id attribute. The "
"latter will be used, and the former will be ignored."
)
else:
assert prompt_text is None
request = self.input_processor.process_inputs(
@ -314,6 +317,11 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
self.input_processor.assign_request_id(request)
# Create a new output collector for the request.
queue = RequestOutputCollector(params.output_kind, request.request_id)
# Use cloned params that may have been updated in process_inputs()
params = request.params
@ -325,7 +333,7 @@ class AsyncLLM(EngineClient):
assert isinstance(parent_params, SamplingParams)
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params)
parent_request = ParentRequest(request)
for idx in range(parent_params.n):
request_id, child_params = parent_request.get_child_info(idx)
child_request = request if idx == parent_params.n - 1 else copy(request)
@ -396,6 +404,7 @@ class AsyncLLM(EngineClient):
"prompt logprobs"
)
q: RequestOutputCollector | None = None
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
@ -446,7 +455,8 @@ class AsyncLLM(EngineClient):
# is cancelled or the generator is garbage collected. So,
# we abort the request if we end up here.
except (asyncio.CancelledError, GeneratorExit):
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
@ -465,7 +475,8 @@ class AsyncLLM(EngineClient):
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
@ -541,13 +552,15 @@ class AsyncLLM(EngineClient):
self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str | Iterable[str]) -> None:
async def abort(
self, request_id: str | Iterable[str], internal: bool = False
) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""
request_ids = (
(request_id,) if isinstance(request_id, str) else as_list(request_id)
)
all_request_ids = self.output_processor.abort_requests(request_ids)
all_request_ids = self.output_processor.abort_requests(request_ids, internal)
await self.engine_core.abort_requests_async(all_request_ids)
if self.log_requests:
@ -581,7 +594,7 @@ class AsyncLLM(EngineClient):
if not wait_for_inflight_requests:
request_ids = list(self.output_processor.request_states.keys())
if request_ids:
await self.abort(request_ids)
await self.abort(request_ids, internal=True)
# Wait for running requests to drain before clearing cache.
if self.output_processor.has_unfinished_requests():
@ -633,6 +646,7 @@ class AsyncLLM(EngineClient):
TODO: Remove truncate_prompt_tokens in v0.15.
"""
q: RequestOutputCollector | None = None
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
@ -687,7 +701,8 @@ class AsyncLLM(EngineClient):
# If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
@ -706,7 +721,8 @@ class AsyncLLM(EngineClient):
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e

View File

@ -21,7 +21,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import (
@ -424,6 +424,19 @@ class InputProcessor:
return mm_hash
return f"{lora_request.lora_name}:{mm_hash}"
@staticmethod
def assign_request_id(request: EngineCoreRequest):
"""Replace the externally supplied request ID with an internal request ID
that adds 8 random characters in order to ensure uniquness.
"""
if request.external_req_id is not None:
raise ValueError(
"The external_req_id field should not be set on EngineCoreRequests"
" passed to vLLM; use the request_id field."
)
request.external_req_id = request.request_id
request.request_id = f"{request.external_req_id}-{random_uuid():.8}"
def process_inputs(
self,
request_id: str,

View File

@ -213,10 +213,10 @@ class LLMEngine:
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks()
def abort_request(self, request_ids: list[str]) -> None:
def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
request_ids = self.output_processor.abort_requests(request_ids)
request_ids = self.output_processor.abort_requests(request_ids, internal)
self.engine_core.abort_requests(request_ids)
def add_request(
@ -238,6 +238,12 @@ class LLMEngine:
# Process raw inputs into the request.
if isinstance(prompt, EngineCoreRequest):
request = prompt
if request_id != request.request_id:
logger.warning_once(
"AsyncLLM.add_request() was passed a request_id parameter that "
"does not match the EngineCoreRequest.request_id attribute. The "
"latter will be used, and the former will be ignored."
)
else:
assert prompt_text is None
request = self.input_processor.process_inputs(
@ -255,6 +261,8 @@ class LLMEngine:
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
self.input_processor.assign_request_id(request)
# Use cloned params that may have been updated in process_inputs()
params = request.params
@ -268,7 +276,7 @@ class LLMEngine:
return
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
parent_req = ParentRequest(request)
for idx in range(n):
request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast
@ -40,8 +41,9 @@ class RequestOutputCollector:
producer gets ahead of the consumer.
"""
def __init__(self, output_kind: RequestOutputKind):
def __init__(self, output_kind: RequestOutputKind, request_id: str):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.request_id = request_id
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
self.ready = asyncio.Event()
@ -92,6 +94,7 @@ class RequestState:
def __init__(
self,
request_id: str,
external_req_id: str,
parent_req: ParentRequest | None,
request_index: int,
lora_request: LoRARequest | None,
@ -111,6 +114,7 @@ class RequestState:
temperature: float | None = None,
):
self.request_id = request_id
self.external_req_id = external_req_id
self.parent_req = parent_req
self.request_index = request_index
self.lora_request = lora_request
@ -176,8 +180,10 @@ class RequestState:
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
assert request.external_req_id is not None
return cls(
request_id=request.request_id,
external_req_id=request.external_req_id,
parent_req=parent_req,
request_index=request_index,
lora_request=request.lora_request,
@ -235,10 +241,13 @@ class RequestState:
]
self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
request_id = self.request_id
external_req_id = self.external_req_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)], finished
external_req_id,
[self._new_pooling_output(pooling_output)],
finished,
)
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
@ -246,19 +255,18 @@ class RequestState:
if self.parent_req is None:
outputs = [output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, output
)
outputs, finished = self.parent_req.get_outputs(self.request_id, output)
if not outputs:
return None
external_req_id = self.parent_req.external_req_id
return self._new_request_output(
request_id, outputs, finished, kv_transfer_params
external_req_id, outputs, finished, kv_transfer_params
)
def _new_request_output(
self,
request_id: str,
external_req_id: str,
outputs: list[CompletionOutput] | list[PoolingOutput],
finished: bool,
kv_transfer_params: dict[str, Any] | None = None,
@ -269,7 +277,7 @@ class RequestState:
# Prompt embeddings are currently not supported by pooling requests.
assert self.prompt_token_ids is not None
return PoolingRequestOutput(
request_id=request_id,
request_id=external_req_id,
outputs=first_output,
num_cached_tokens=self.num_cached_tokens,
prompt_token_ids=self.prompt_token_ids,
@ -288,7 +296,7 @@ class RequestState:
prompt_token_ids = [0] * len(self.prompt_embeds)
return RequestOutput(
request_id=request_id,
request_id=external_req_id, # request_id is what was provided externally
lora_request=self.lora_request,
prompt=self.prompt,
prompt_token_ids=prompt_token_ids,
@ -352,6 +360,7 @@ class OutputProcessor:
self.stream_interval = stream_interval
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None
self._requests_drained = asyncio.Event()
@ -375,12 +384,41 @@ class OutputProcessor:
assert state.queue is not None
state.queue.put(e)
def abort_requests(
self,
request_ids: Iterable[str],
) -> list[str]:
request_ids_to_abort = []
def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
"""Abort a list of requests.
The request_ids may be either external request IDs (those passed to
InputProcessor.process_inputs()) or internal request IDs (those randomly
generated when creating the EngineCoreRequest).
If an external request ID is provided, and that external request ID
was used for multiple requests, all requests associated with that external
request ID are aborted.
In the case of parallel sampling, a request ID may be used to identify
a parent request, in which case the associated child requests are aborted
also.
"""
internal_req_ids = []
for request_id in request_ids:
if internal:
# Internal ID - this may be a parent request
internal_req_ids.append(request_id)
# Remove internal ID from the external->internal mapping
if req_state := self.request_states.get(request_id):
external_req_id = req_state.external_req_id
internal_ids = self.external_req_ids[external_req_id]
internal_ids.remove(request_id)
if not internal_ids:
del self.external_req_ids[external_req_id]
elif internal_ids := self.external_req_ids.pop(request_id, []):
# External ID - abort all requests in the external->internal mapping
internal_req_ids.extend(internal_ids)
request_ids_to_abort = []
for request_id in internal_req_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.request_finished(request_id, req_state.lora_name)
@ -404,7 +442,7 @@ class OutputProcessor:
# Abort children prior to removing the parent.
if parent.child_requests:
child_reqs = list(parent.child_requests)
child_reqs = self.abort_requests(child_reqs)
child_reqs = self.abort_requests(child_reqs, internal=True)
request_ids_to_abort.extend(child_reqs)
self.parent_requests.pop(request_id, None)
if not self.request_states:
@ -439,6 +477,9 @@ class OutputProcessor:
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
# Track the external_req_id -> [internal_req_id, ...] mapping
self.external_req_ids[req_state.external_req_id].append(request_id)
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
@ -522,6 +563,12 @@ class OutputProcessor:
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
internal_ids = self.external_req_ids[req_state.external_req_id]
internal_ids.remove(req_id)
if not internal_ids:
del self.external_req_ids[req_state.external_req_id]
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
@ -597,7 +644,9 @@ class OutputProcessor:
)
# meta
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
)
if req_state.top_p:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
if req_state.max_tokens_param:

View File

@ -6,6 +6,7 @@ from typing import Optional, cast
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import IterationStats
@ -17,6 +18,7 @@ class ParentRequest:
"""
request_id: str
external_req_id: str
sampling_params: SamplingParams
# To track the completion of child requests
@ -31,8 +33,11 @@ class ParentRequest:
# To efficiently obtain child sampling params
cached_child_sampling_params: SamplingParams | None
def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
self.request_id = request_id
def __init__(self, request: EngineCoreRequest) -> None:
assert request.external_req_id is not None
sampling_params = request.params
self.request_id = request.request_id
self.external_req_id = request.external_req_id
self.sampling_params = sampling_params
self.child_requests = set()
@ -96,7 +101,7 @@ class ParentRequest:
self,
child_request_id: str,
completion_output: CompletionOutput,
) -> tuple[str, list[CompletionOutput], bool]:
) -> tuple[list[CompletionOutput], bool]:
already_finished_and_returned: bool = False
if completion_output.finished():
if child_request_id in self.child_requests:
@ -118,7 +123,7 @@ class ParentRequest:
outputs = [] if self.child_requests else self.output_aggregator
finished = not self.child_requests
return self.request_id, outputs, finished
return outputs, finished
def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max(

Some files were not shown because too many files have changed in this diff Show More