Merge branch 'main' into woosuk/model-runner-v2

This commit is contained in:
Woosuk Kwon 2025-09-24 08:19:25 -07:00
commit ad2cf805ad
143 changed files with 4210 additions and 1566 deletions

View File

@ -0,0 +1,59 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Setup script for Prime-RL integration tests
# This script prepares the environment for running Prime-RL tests with nightly vLLM
#
# Steps: clone Prime-RL next to the vLLM checkout, drop its vllm pin so the
# already-installed vLLM is used, sync the remaining dependencies with uv, and
# run the GPU integration test.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
PRIME_RL_REPO="https://github.com/PrimeIntellect-ai/prime-rl.git"
PRIME_RL_DIR="${REPO_ROOT}/prime-rl"

echo "Setting up Prime-RL integration test environment..."

# Clean up any existing Prime-RL directory
if [ -d "${PRIME_RL_DIR}" ]; then
    echo "Removing existing Prime-RL directory..."
    rm -rf "${PRIME_RL_DIR}"
fi

# Install UV if not available
if ! command -v uv &> /dev/null; then
    echo "Installing UV package manager..."
    curl -LsSf https://astral.sh/uv/install.sh | sh
    # Quoted so a HOME containing spaces does not split the path.
    source "${HOME}/.local/bin/env"
fi

# Clone Prime-RL repository at specific branch for reproducible tests
PRIME_RL_BRANCH="integ-vllm-main"
echo "Cloning Prime-RL repository at branch: ${PRIME_RL_BRANCH}..."
git clone --branch "${PRIME_RL_BRANCH}" --single-branch "${PRIME_RL_REPO}" "${PRIME_RL_DIR}"
cd "${PRIME_RL_DIR}"

echo "Setting up UV project environment..."
export UV_PROJECT_ENVIRONMENT=/usr/local
# Only create the `python` alias when no usable one exists: a plain `ln -s`
# fails (and, under `set -e`, aborts the script) if the path is already there.
# `-f` replaces a dangling symlink, which `-e` reports as missing.
if [ ! -e /usr/local/bin/python ]; then
    ln -sf /usr/bin/python3 /usr/local/bin/python
fi

# Remove vllm pin from pyproject.toml
echo "Removing vllm pin from pyproject.toml..."
sed -i '/vllm==/d' pyproject.toml

# Sync Prime-RL dependencies
echo "Installing Prime-RL dependencies..."
# Run as two separate commands: in `a && b`, a failure of `a` is exempt from
# `set -e`, so the chained form would silently continue after a failed sync.
uv sync --inexact
uv sync --inexact --all-extras

# Verify installation
echo "Verifying installations..."
uv run python -c "import vllm; print(f'vLLM version: {vllm.__version__}')"
uv run python -c "import prime_rl; print('Prime-RL imported successfully')"

echo "Prime-RL integration test environment setup complete!"

echo "Running Prime-RL integration tests..."
export WANDB_MODE=offline # this makes this test not require a WANDB_API_KEY
uv run pytest -vs tests/integration/test_rl.py -m gpu

echo "Prime-RL integration tests completed!"

View File

@ -887,6 +887,8 @@ steps:
- tests/v1/test_external_lb_dp.py
- tests/v1/entrypoints/openai/test_multi_api_servers.py
- vllm/v1/engine/
- vllm/v1/worker/
- tests/v1/worker/test_worker_memory_snapshot.py
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
@ -908,6 +910,7 @@ steps:
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s models/multimodal/generation/test_maverick.py
- pytest -v -s v1/worker/test_worker_memory_snapshot.py
- label: Plugin Tests (2 GPUs) # 40min
timeout_in_minutes: 60
@ -1042,3 +1045,15 @@ steps:
commands:
- pytest -v -s tests/distributed/test_context_parallel.py
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
##### RL Integration Tests #####
- label: Prime-RL Integration Test # 15min
timeout_in_minutes: 30
optional: true
num_gpus: 2
working_dir: "/vllm-workspace"
source_file_dependencies:
- vllm/
- .buildkite/scripts/run-prime-rl-test.sh
commands:
- bash .buildkite/scripts/run-prime-rl-test.sh

View File

@ -449,7 +449,8 @@ async def benchmark(
def prepare_extra_body(request) -> dict:
extra_body = {}
# Add the schema to the extra_body
extra_body[request.structure_type] = request.schema
extra_body["structured_outputs"] = {}
extra_body["structured_outputs"][request.structure_type] = request.schema
return extra_body
print("Starting initial single prompt test run...")

View File

@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe
kernel. Both kernels take in fp8 quantized weights and 16-bit activations,
but use different quantization strategies and backends.
"""
import nvtx
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
# Weight shapes for different models: [num_experts, topk, hidden_size,
# intermediate_size]
# main() walks every shape listed for a selected model and divides
# intermediate_size by the tensor-parallel size before benchmarking.
WEIGHT_SHAPES_MOE = {
"mixtral-8x7b": [
[8, 2, 4096, 14336],
],
"deepseek-v2": [
[160, 6, 5120, 12288],
],
"custom-small": [
[8, 2, 2048, 7168],
],
"glm45-fp8": [
[128, 8, 4096, 1408],
],
"Llama-4-Maverick-17B-128E-Instruct-FP8": [
[128, 1, 5120, 8192],
],
}
# Default for --models.
DEFAULT_MODELS = [
"mixtral-8x7b",
]
# Default for --batch-sizes (the "m" dimension handed to bench_run).
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
# Default for --tp-sizes.
DEFAULT_TP_SIZES = [1]
# Both settings of the per-activation-token / per-output-channel switches.
# NOTE(review): presumably intended as the defaults of --per-act-token-opts
# and --per-out-ch-opts - confirm.
PER_ACT_TOKEN_OPTS = [False, True]
PER_OUT_CH_OPTS = [False, True]
# FP8 dtype reported by the current platform; bench_run allocates the
# quantized weight buffers with it.
FP8_DTYPE = current_platform.fp8_dtype()
def bench_run(
    results: list,
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    """Benchmark Triton vs. CUTLASS FP8 MoE for a single (m, k, n) problem.

    Random fp16 activations and expert weights are created on the CUDA
    device, the weights are quantized to FP8 with ``ops.scaled_fp8_quant``,
    and each kernel is captured in its own CUDA graph and timed with CUDA
    events.

    Args:
        results: Unused; kept so existing positional callers keep working.
        model: Unused; kept so existing positional callers keep working.
        num_experts: Number of experts (leading dim of the weight tensors).
        topk: Number of experts selected per token by ``fused_topk``.
        per_act_token: Ignored. Activation scales are always per-tensor
            (see the workaround below), so this is overridden to ``False``.
        per_out_ch: If true, weight scales are stored with a
            per-output-channel shape; the values are still per-tensor scales
            broadcast to that shape.
        mkn: ``(m, k, n)`` = (batch size, hidden size, intermediate size).

    Returns:
        Dict with ``batch_size`` and the mean per-call latency in
        microseconds of each kernel (``triton_time_us``,
        ``cutlass_time_us``).
    """
    (m, k, n) = mkn

    # Number of kernel invocations captured in each CUDA graph (same pattern
    # as benchmark_moe.py). The replay latency is divided by this value, so
    # the capture loops and the divisor must share one constant.
    calls_per_graph = 10

    dtype = torch.half
    device = "cuda"

    # Create input activations
    a = torch.randn((m, k), device=device, dtype=dtype) / 10

    # Create weights
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    # FP8 quantized weights shared by both kernels. CUTLASS consumes them in
    # this layout as well ([E, 2N, K] and [E, K, N]); no transpose is needed.
    w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE)
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE)

    # Scale buffers: per-channel shape or one scalar per expert.
    if per_out_ch:
        w1_scale = torch.empty(
            (num_experts, 2 * n, 1), device=device, dtype=torch.float32
        )
        w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32)
    else:
        w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
        w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)

    # Quantize weights. True per-channel quantization is not implemented yet,
    # so both modes quantize per tensor; only the stored scale shape differs.
    for expert in range(num_experts):
        w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
        w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
        if per_out_ch:
            # Expand scalar scales to the expected per-channel shape
            w1_scale[expert] = w1_scale_temp.expand(2 * n, 1)
            w2_scale[expert] = w2_scale_temp.expand(k, 1)
        else:
            # Store scalar scales in [1, 1] tensors
            w1_scale[expert, 0, 0] = w1_scale_temp
            w2_scale[expert, 0, 0] = w2_scale_temp

    # Create router scores and get topk
    score = torch.randn((m, num_experts), device=device, dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    # WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization.
    # Force per-tensor activation scales for all cases to match the working
    # e2e setup.
    a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
    a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
    per_act_token = False

    # Create stride tensors for CUTLASS
    ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
    ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
    c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
    c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)

    # Pre-create quantization config to avoid creating it inside CUDA graph
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
    )

    # Capture the CUTLASS kernel in a CUDA graph on its own stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        for _ in range(calls_per_graph):
            cutlass_moe_fp8(
                a=a,
                w1_q=w1_fp8q,
                w2_q=w2_fp8q,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                ab_strides1=ab_strides1,
                ab_strides2=ab_strides2,
                c_strides1=c_strides1,
                c_strides2=c_strides2,
                quant_config=quant_config,
                activation="silu",
                global_num_experts=num_experts,
            )
    torch.cuda.synchronize()

    # Capture the Triton kernel in a CUDA graph on its own stream.
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        for _ in range(calls_per_graph):
            fused_experts(
                a,
                w1_fp8q,
                w2_fp8q,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )
    torch.cuda.synchronize()

    def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
        """Return the mean per-call latency (ms) of replaying ``graph``."""
        # Warmup
        for _ in range(num_warmup):
            graph.replay()
        torch.cuda.synchronize()

        # Timing
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        latencies = []
        for _ in range(num_iters):
            torch.cuda.synchronize()
            start_event.record()
            graph.replay()
            end_event.record()
            end_event.synchronize()
            latencies.append(start_event.elapsed_time(end_event))

        # Each replay runs `calls_per_graph` kernel invocations.
        return sum(latencies) / (num_iters * calls_per_graph)

    # Benchmark parameters
    num_warmup = 5
    num_iters = 100

    # Benchmark only CUDA graphs (more reliable and faster)
    triton_graph_time = bench_cuda_graph(
        triton_graph, num_warmup=num_warmup, num_iters=num_iters
    )
    cutlass_graph_time = bench_cuda_graph(
        cutlass_graph, num_warmup=num_warmup, num_iters=num_iters
    )

    # Convert ms to us and return results
    return {
        "batch_size": m,
        "triton_time_us": triton_graph_time * 1000,
        "cutlass_time_us": cutlass_graph_time * 1000,
    }
def main(args):
    """Run the benchmark sweep described by ``args`` and print result tables.

    For every selected model, tensor-parallel size and weight shape, each
    combination of the per-activation-token and per-output-channel options is
    benchmarked over all batch sizes via ``bench_run``; one table is printed
    per combination.

    Args:
        args: Parsed command-line namespace providing ``models``,
            ``tp_sizes``, ``batch_sizes``, ``limit_k``, ``limit_n``,
            ``per_act_token_opts`` and ``per_out_ch_opts``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    all_results = []

    for model in args.models:
        for tp in args.tp_sizes:
            # Each shape is [num_experts, topk, hidden_size, intermediate_size].
            for num_experts, topk, size_k, full_n in WEIGHT_SHAPES_MOE[model]:
                # The intermediate dimension is sharded across TP ranks.
                size_n = full_n // tp

                # Empty limit lists mean "no filtering".
                if args.limit_k and size_k not in args.limit_k:
                    continue
                if args.limit_n and size_n not in args.limit_n:
                    continue

                for per_act_token in args.per_act_token_opts:
                    for per_out_ch in args.per_out_ch_opts:
                        print(
                            f"\n=== {model}, experts={num_experts}, topk={topk}, "
                            f"per_act={per_act_token}, per_out_ch={per_out_ch} ==="
                        )

                        config_results = []
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            result = bench_run(
                                [],  # Not used anymore
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )
                            if result:
                                config_results.append(result)

                        # Print results table for this configuration
                        if config_results:
                            print(
                                f"\n{'Batch Size':<12}"
                                f"{'Triton (us)':<15}"
                                f"{'CUTLASS (us)':<15}"
                            )
                            print("-" * 45)
                            for result in config_results:
                                print(
                                    f"{result['batch_size']:<12}"
                                    f"{result['triton_time_us']:<15.2f}"
                                    f"{result['cutlass_time_us']:<15.2f}"
                                )

                        all_results.extend(config_results)

    print(f"\nTotal benchmarks completed: {len(all_results)}")
if __name__ == "__main__":
    # Command-line entry point: build the parser, parse, and run the sweep.
    parser = FlexibleArgumentParser(
        description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE
        across specified models/shapes/batches

        Example usage:
        python benchmark_cutlass_moe_fp8.py  \
            --models "Llama-4-Maverick-17B-128E-Instruct-FP8"  \
            --tp-sizes 8 \
            --batch-sizes 2 4 8  \
            --per-act-token-opts false \
            --per-out-ch-opts false

        """
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    # Empty list = do not filter on that dimension.
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    # The boolean options accept "true" (case-insensitive); any other value
    # is interpreted as False.
    parser.add_argument(
        "--per-act-token-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=PER_ACT_TOKEN_OPTS,
        help="Per-activation token quantization options (true/false)",
    )
    parser.add_argument(
        "--per-out-ch-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=PER_OUT_CH_OPTS,
        help="Per-output channel quantization options (true/false)",
    )

    args = parser.parse_args()
    main(args)

View File

@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp")
"csrc/cpu/torch_bindings.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC

View File

@ -135,10 +135,10 @@ public:
max_splits = min(16, max_splits);
// TODO: This avoids a hang when the batch size larger than 1 and
// there is more than 4 kv_splits.
// there is more than 1 kv_splits.
// Discuss with NVIDIA how this can be fixed.
if (B > 1) {
max_splits = min(2, max_splits);
max_splits = min(1, max_splits);
}
// printf(" max_splits = %d\n", max_splits);

View File

@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
ops.def(
"dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
"Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
"int group_size, bool apply_router_weight_on_input, int activation_kind"
") -> Tensor");
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
// PagedAttention V2.
ops.def(
"paged_attention_v2("

View File

@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/all.h>
// _dyn_quant_matmul_4bit is only available on AArch64.
#if defined(__aarch64__)
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
#endif
// Dispatches one dynamically-quantized 4-bit matmul.
// On AArch64 the arguments are forwarded unchanged to
// at::_ops::_dyn_quant_matmul_4bit::call; on every other architecture the
// function always raises through TORCH_CHECK(false, ...).
inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
int64_t group_size_eff, int64_t in_features,
int64_t out_features) {
#if defined(__aarch64__)
return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
in_features, out_features);
#else
TORCH_CHECK(false,
"dynamic 4-bit int MoE path requires AArch64 (ARM64); "
"_dyn_quant_matmul_4bit is unavailable on this architecture");
// Not reached (the check above always fails); presumably kept so this
// path still ends in a return statement.
return {};
#endif
}
// Activation selector passed to dynamic_4bit_int_moe_cpu as `activation_kind`.
// "g" is the first I columns of the W13 output, "u" the following I columns.
enum ActivationKind : int64_t {
SwiGLU_Gu = 0, // act = SiLU(g) * u
SwiGLUOAI = 1, // act = (clamp(u, -7, 7) + 1) * g' * sigmoid(1.702 * g'), g' = min(g, 7)
SiLU = 2 // handled by the same branch as SwiGLU_Gu: act = SiLU(g) * u
};
// Mixture-of-experts forward pass on CPU using dynamically quantized 4-bit
// expert weights.
//
// Tokens are bucketed by their selected expert, each expert runs
// W13 -> activation -> W2 on its bucket (experts are processed with
// at::parallel_for), and the per-expert outputs are summed back per token
// with index_add.
//
//   x             2-D input activations, one row per token (T rows).
//   topk_ids      [T, K] expert index per (token, slot); read as int64.
//   topk_weights  [T, K] router weight per (token, slot); converted to float.
//   w13_packed,
//   w2_packed     packed expert weights, expert index in dim 0.
//   H, I, I2      feature sizes handed to the matmuls; I2 must equal 2*I.
//   group_size    quantization group size; -1 selects H (W13) or I (W2).
//   apply_router_weight_on_input
//                 scale each expert's input by the router weight instead of
//                 scaling its output.
//   activation_kind
//                 ActivationKind value; SwiGLUOAI selects the clamped
//                 variant, every other value uses SiLU(g) * u.
//
// Returns a [T, H] tensor created with x's options. Raises (TORCH_CHECK) on
// malformed shapes or an expert id outside [0, E).
torch::Tensor dynamic_4bit_int_moe_cpu(
    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
    int64_t activation_kind) {
  TORCH_CHECK(x.dim() == 2, "x must be 2D");
  TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
              "topk tensors must be [T, K]");
  // The raw-pointer loops below read T * K elements from both topk tensors,
  // so a shape mismatch would otherwise be an out-of-bounds read.
  TORCH_CHECK(topk_ids.size(0) == x.size(0),
              "topk_ids must have one row per row of x");
  TORCH_CHECK(topk_weights.sizes() == topk_ids.sizes(),
              "topk_weights and topk_ids must have the same shape");
  TORCH_CHECK(
      w13_packed.size(0) == w2_packed.size(0),
      "w13_packed and w2_packed must have same number of experts in dim 0");
  TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");

  const int64_t T = x.size(0);
  const int64_t K = topk_ids.size(1);
  const int64_t E = w13_packed.size(0);
  const int64_t N = T * K;

  // Nothing is routed anywhere: the result is all zeros. Returning early also
  // avoids at::cat on an empty list when there are no experts.
  if (N == 0) {
    return at::zeros({T, H}, x.options());
  }

  auto x_c = x.contiguous();
  auto ids_c = topk_ids.contiguous();
  auto gates_c = topk_weights.to(at::kFloat).contiguous();

  // bucketing tokens -> experts
  // (inline storage covers up to 64 experts; larger E spills to the heap)
  c10::SmallVector<int64_t, 64> counts(E, 0);
  {
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    for (int64_t i = 0; i < N; ++i) {
      const int64_t e_id = ids_ptr[i];
      TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
      counts[e_id]++;
    }
  }

  // Exclusive prefix sum: expert e owns rows [offsets[e], offsets[e + 1]).
  c10::SmallVector<int64_t, 65> offsets(E + 1, 0);
  for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];

  auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
  auto expert_gates = at::empty({offsets[E]}, gates_c.options());
  {
    // cursor[e] is the next free slot inside expert e's bucket.
    c10::SmallVector<int64_t, 64> cursor(E, 0);
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    const auto* gts_ptr = gates_c.data_ptr<float>();
    auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
    auto* gate_ptr = expert_gates.data_ptr<float>();

    for (int64_t t = 0; t < T; ++t) {
      const int64_t base = t * K;
      for (int64_t k = 0; k < K; ++k) {
        const int64_t idx = base + k;
        const int64_t e = ids_ptr[idx];  // range-checked in the pass above
        const int64_t p = offsets[e] + (cursor[e]++);
        tok_ptr[p] = t;
        gate_ptr[p] = gts_ptr[idx];
      }
    }
  }

  const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
  const int64_t g_eff_2 = (group_size != -1) ? group_size : I;

  // Per-expert outputs filled in parallel; each iteration writes only its own
  // y_list[e] slot.
  std::vector<torch::Tensor> y_list(E);

  at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
    for (int64_t e = e_begin; e < e_end; ++e) {
      const int64_t te = counts[e];
      if (te == 0) {
        // Keep a [0, H] placeholder so at::cat below sees E tensors.
        y_list[e] = at::empty({0, H}, x_c.options());
        continue;
      }

      const int64_t start = offsets[e];

      auto sel_tokens =
          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
      auto gates_e =
          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);

      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);

      if (apply_router_weight_on_input) {
        x_e = x_e.mul(gates_e.unsqueeze(1));
      }

      auto w13_e = w13_packed.select(/*dim=*/0, e);
      auto w2_e = w2_packed.select(/*dim=*/0, e);

      // W13
      auto y13 =
          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);

      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);

      torch::Tensor act;
      if (activation_kind == ActivationKind::SwiGLUOAI) {  // SwiGLUOAI
        constexpr double kAlpha = 1.702;  // GPT-OSS default
        constexpr double kLimit = 7.0;    // GPT-OSS default
        auto gate_c = at::clamp_max(g_part, kLimit);
        auto up_c = at::clamp(u_part, -kLimit, kLimit);
        auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
        act = up_c.add(1.0).mul(glu);
      } else {  // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul()
        act = at::silu(g_part).mul(u_part);
      }

      // W2
      auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);

      if (!apply_router_weight_on_input) {
        y = y.mul(gates_e.unsqueeze(1));
      }

      // Store per-expert result
      y_list[e] = y;
    }
  });

  // Concatenate all expert outputs to match expert_tokens order
  auto Y_all = at::cat(y_list, /*dim=*/0);
  auto out = at::zeros({T, H}, x.options());
  out =
      at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);

  return out;
}

View File

@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id);
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
int64_t activation_kind);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,

View File

@ -23,9 +23,14 @@
typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
#if defined(HIP_FP8_TYPE_OCP)
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
#else
// ROCm 6.2 fallback: only *_fnuz types exist
typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3;
#endif
#endif
#include "core/registration.h"

View File

@ -25,6 +25,12 @@
#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent
#if !defined(HIP_FP8_TYPE_OCP)
using __hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz;
using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#endif
#if defined(__HIPCC__) && \
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
#define __HIP__GFX9__

View File

@ -34,7 +34,7 @@ Now supports 5 types of connectors:
For NixlConnector, you may also specify one or more NIXL_Backend values, for example:
```bash
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'
```
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):

View File

@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'`
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL_Backend values, for example: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'`
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.

View File

@ -24,7 +24,7 @@ outlines_core == 0.2.11
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
xgrammar == 0.1.24; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
xgrammar == 0.1.25; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs

View File

@ -1079,7 +1079,7 @@ def dummy_llava_path():
local_dir=_dummy_llava_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
"*.msgpack", "*.safetensors"
])
assert os.path.exists(json_path)
with open(json_path) as f:
@ -1098,7 +1098,7 @@ def dummy_gemma2_embedding_path():
local_dir=_dummy_gemma2_embedding_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
"*.msgpack", "*.safetensors"
])
assert os.path.exists(json_path)
with open(json_path) as f:

View File

@ -382,7 +382,6 @@ def test_tp_language_generation(
test_options: PPTestOptions,
num_gpus_available,
):
pytest.skip("Skipping the test until V1 passes it.")
_compare_tp(model_id,
parallel_setup,
distributed_backend,
@ -410,7 +409,6 @@ def test_tp_language_embedding(
test_options: PPTestOptions,
num_gpus_available,
):
pytest.skip("Skipping the test until V1 passes it.")
_compare_tp(model_id,
parallel_setup,
distributed_backend,
@ -438,7 +436,6 @@ def test_tp_multimodal_generation(
test_options: PPTestOptions,
num_gpus_available,
):
pytest.skip("Skipping the test until V1 passes it.")
_compare_tp(model_id,
parallel_setup,
distributed_backend,

View File

@ -523,6 +523,7 @@ async def test_function_calling(client: OpenAI, model_name: str):
input="What's the weather like in Paris today?",
tools=tools,
temperature=0.0,
extra_body={"request_id": "test_function_calling_non_resp"},
)
assert response is not None
assert response.status == "completed"

View File

@ -194,6 +194,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
assert tc.function is not None and tc.function.name == "get_current_weather"
args1 = tc.function.arguments
assert args1 is not None and len(args1) > 0
assert not first_msg.content
messages.append({"role": "assistant", "content": args1})
messages.append({

View File

@ -85,8 +85,7 @@ def test_env(
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size,
False)
backend = get_attn_backend(16, torch.float16, None, block_size)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
elif device == "hip":
@ -106,7 +105,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
@ -117,7 +115,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
@ -127,7 +124,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1"
assert backend.get_name() == expected
@ -136,7 +132,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "TRITON_ATTN_VLLM_V1"
assert backend.get_name() == expected
@ -164,7 +159,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "CUTLASS_MLA_VLLM_V1"
assert backend.get_name() == expected
@ -179,7 +173,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
@ -199,7 +192,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1"
assert backend.get_name() == expected
@ -208,7 +200,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
@ -218,7 +209,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "TRITON_MLA_VLLM_V1"
assert backend.get_name() == expected
@ -227,7 +217,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASHINFER_VLLM_V1"
assert backend.get_name() == expected
@ -236,7 +225,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_VLLM_V1"
assert backend.get_name() == expected
@ -245,7 +233,6 @@ def test_env(
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == "FLEX_ATTENTION", (
"Should fallback to FlexAttention if head size is "
@ -264,13 +251,13 @@ def test_fp32_fallback(
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16, False)
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16, False)
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "FLEX_ATTENTION"
@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(torch.cuda,
"get_device_capability",
lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16, False)
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Reset the monkeypatch for subsequent tests
monkeypatch.undo()
# Unsupported data type
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
backend = get_attn_backend(16, torch.float16, "fp8", 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = get_attn_backend(16, torch.float16, None, 8, False)
backend = get_attn_backend(16, torch.float16, None, 8)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# flash-attn is not installed
import sys
original_module = sys.modules.get('vllm_flash_attn')
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
backend = get_attn_backend(16, torch.float16, None, 16, False)
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Restore the original module if it existed
@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
# Unsupported head size
backend = get_attn_backend(17, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = get_attn_backend(16, torch.float16, None, 16, True)
backend = get_attn_backend(17, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
# Should raise ValueError for invalid backend
with pytest.raises(ValueError) as exc_info:
get_attn_backend(32, torch.float16, None, 16, False)
get_attn_backend(32, torch.float16, None, 16)
assert "Invalid value 'INVALID'" in str(exc_info.value)

View File

@ -12,11 +12,11 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens)

View File

@ -18,10 +18,10 @@ from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import InputProcessingContext
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of

View File

@ -42,7 +42,6 @@ def test_oot_registration_text_generation(
assert rest == ""
@pytest.mark.skip(reason="This test is skipped because it failed on V1.")
@create_new_process_for_each_test()
def test_oot_registration_embedding(
monkeypatch: pytest.MonkeyPatch,
@ -63,7 +62,6 @@ def test_oot_registration_embedding(
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
@pytest.mark.skip(reason="This test is skipped because it failed on V1.")
@create_new_process_for_each_test()
def test_oot_registration_multimodal(
monkeypatch: pytest.MonkeyPatch,

View File

@ -11,8 +11,9 @@ import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ModelDType, RunnerOption
from vllm.inputs import InputContext
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from vllm.multimodal.processing import InputProcessingContext
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .registry import HF_EXAMPLE_MODELS
@ -264,7 +265,7 @@ def build_model_context(
limit_mm_per_prompt: Optional[dict[str, int]] = None,
mm_processor_cache_gb: int = 0,
):
"""Creates an InputContext for a given model.
"""Creates an InputProcessingContext for a given model.
Args:
model_id: ID of the model being considered.
@ -273,7 +274,7 @@ def build_model_context(
limit_mm_per_prompt: Multimodal limits.
Returns:
InputContext for the model being considered.
InputProcessingContext for the model being considered.
"""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
@ -298,7 +299,11 @@ def build_model_context(
enforce_eager=model_info.enforce_eager,
**model_config_kwargs,
)
return InputContext(model_config)
return InputProcessingContext(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),
)
def check_embeddings_close(

View File

@ -8,11 +8,11 @@ import numpy as np
import pytest
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
from vllm.multimodal.processing import (InputProcessingContext,
PlaceholderFeaturesInfo,
PromptIndexTargets, PromptInsertion,
PromptReplacement, apply_text_matches,
apply_token_matches,

View File

@ -0,0 +1,392 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
# Create a concrete test implementation of BaseThinkingReasoningParser
class TestThinkingReasoningParser(BaseThinkingReasoningParser):
    """Concrete BaseThinkingReasoningParser used by the tests in this file.

    Supplies the ``<test:think>`` / ``</test:think>`` marker pair through the
    two properties the tests read back from the parser.
    """

    # The ``Test`` prefix matches pytest's default class-collection pattern,
    # but this is a helper that takes a constructor argument, not a test
    # case. Opting out of collection avoids a PytestCollectionWarning.
    __test__ = False

    @property
    def start_token(self) -> str:
        """Marker string that opens a reasoning section."""
        return "<test:think>"

    @property
    def end_token(self) -> str:
        """Marker string that closes a reasoning section."""
        return "</test:think>"
class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser):
    """Second concrete parser with a different marker pair.

    Uses ``<alt:start>`` / ``<alt:end>`` so tests can check that two
    implementations with different markers work side by side.
    """

    # Named ``Test*`` but not a test case: keep pytest from trying to collect
    # it (it takes a constructor argument, which would trigger a
    # PytestCollectionWarning).
    __test__ = False

    @property
    def start_token(self) -> str:
        """Marker string that opens a reasoning section."""
        return "<alt:start>"

    @property
    def end_token(self) -> str:
        """Marker string that closes a reasoning section."""
        return "<alt:end>"
# Use a test model
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def test_tokenizer():
    """Load the reference tokenizer and register the custom marker tokens.

    Only markers that are absent from the loaded vocabulary are added.
    """
    tok = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    markers = ["<test:think>", "</test:think>", "<alt:start>", "<alt:end>"]
    known = set(tok.get_vocab().keys())
    missing = []
    for marker in markers:
        if marker not in known:
            missing.append(marker)
    if missing:
        tok.add_tokens(missing)
    return tok
class TestBaseThinkingReasoningParserInit:
    """Construction-time behaviour of BaseThinkingReasoningParser subclasses."""

    def test_successful_initialization(self, test_tokenizer):
        """Marker strings and their token ids are available after init."""
        instance = TestThinkingReasoningParser(test_tokenizer)

        assert instance.start_token == "<test:think>"
        assert instance.end_token == "</test:think>"
        assert instance.start_token_id is not None
        assert instance.end_token_id is not None

    def test_initialization_with_missing_tokenizer(self):
        """Passing no tokenizer is rejected with a ValueError."""
        expected_message = "model tokenizer must be passed"
        with pytest.raises(ValueError, match=expected_message):
            TestThinkingReasoningParser(None)

    def test_initialization_with_missing_tokens(self, test_tokenizer):
        """Markers that are not in the vocabulary raise a RuntimeError."""

        class MissingTokenParser(BaseThinkingReasoningParser):
            # Neither marker is registered by the tokenizer fixture.

            @property
            def start_token(self) -> str:
                return "<missing:start>"

            @property
            def end_token(self) -> str:
                return "<missing:end>"

        expected_message = "could not locate think start/end tokens"
        with pytest.raises(RuntimeError, match=expected_message):
            MissingTokenParser(test_tokenizer)

    def test_initialization_with_empty_tokens(self, test_tokenizer):
        """Empty marker strings raise a ValueError."""

        class EmptyTokenParser(BaseThinkingReasoningParser):

            @property
            def start_token(self) -> str:
                return ""

            @property
            def end_token(self) -> str:
                return ""

        expected_message = "start_token and end_token must be defined"
        with pytest.raises(ValueError, match=expected_message):
            EmptyTokenParser(test_tokenizer)
class TestBaseThinkingReasoningParserMethods:
    """Token-id level helpers of BaseThinkingReasoningParser."""

    def test_is_reasoning_end(self, test_tokenizer):
        """is_reasoning_end is True only when the end token id is present."""
        parser = TestThinkingReasoningParser(test_tokenizer)
        end_token_id = parser.end_token_id

        # End token present
        assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True

        # End token absent
        assert parser.is_reasoning_end([1, 2, 3, 4]) is False

        # Empty input
        assert parser.is_reasoning_end([]) is False

    def test_extract_content_ids(self, test_tokenizer):
        """extract_content_ids returns the ids that follow the end token."""
        parser = TestThinkingReasoningParser(test_tokenizer)
        end_token_id = parser.end_token_id

        # End token in the middle: everything after it is content
        input_ids = [1, 2, end_token_id, 4, 5]
        assert parser.extract_content_ids(input_ids) == [4, 5]

        # End token is the last element: nothing follows, so no content.
        # (This case used to be asserted twice with identical input.)
        input_ids = [1, 2, 3, end_token_id]
        assert parser.extract_content_ids(input_ids) == []

        # No end token at all
        input_ids = [1, 2, 3, 4]
        assert parser.extract_content_ids(input_ids) == []
class TestBaseThinkingReasoningParserExtraction:
    """Non-streaming extract_reasoning_content behaviour."""

    def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer):
        """Start and end markers are both present in the output."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)
        chat_request = ChatCompletionRequest(messages=[], model="test-model")

        text = "<test:think>This is reasoning</test:think>This is content"
        extracted = thinking_parser.extract_reasoning_content(
            text, chat_request)
        reasoning, content = extracted

        assert reasoning == "This is reasoning"
        assert content == "This is content"

    def test_extract_reasoning_content_only_end_token(self, test_tokenizer):
        """Only the end marker is present in the output."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)
        chat_request = ChatCompletionRequest(messages=[], model="test-model")

        text = "This is reasoning</test:think>This is content"
        extracted = thinking_parser.extract_reasoning_content(
            text, chat_request)
        reasoning, content = extracted

        assert reasoning == "This is reasoning"
        assert content == "This is content"

    def test_extract_reasoning_content_no_end_token(self, test_tokenizer):
        """Without an end marker the whole output is reasoning."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)
        chat_request = ChatCompletionRequest(messages=[], model="test-model")

        text = "This is just content"
        extracted = thinking_parser.extract_reasoning_content(
            text, chat_request)
        reasoning, content = extracted

        assert reasoning == "This is just content"
        assert content is None

    def test_extract_reasoning_content_empty_output(self, test_tokenizer):
        """An empty output gives empty reasoning and no content."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)
        chat_request = ChatCompletionRequest(messages=[], model="test-model")

        text = ""
        extracted = thinking_parser.extract_reasoning_content(
            text, chat_request)
        reasoning, content = extracted

        assert reasoning == ""
        assert content is None

    def test_extract_reasoning_content_only_tokens(self, test_tokenizer):
        """The two markers back to back, with nothing between or after."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)
        chat_request = ChatCompletionRequest(messages=[], model="test-model")

        text = "<test:think></test:think>"
        extracted = thinking_parser.extract_reasoning_content(
            text, chat_request)
        reasoning, content = extracted

        assert reasoning == ""
        assert content is None
class TestBaseThinkingReasoningParserStreaming:
    """Streaming behaviour of BaseThinkingReasoningParser."""

    @pytest.mark.parametrize("streaming", [True, False])
    def test_simple_reasoning_extraction(self, test_tokenizer, streaming):
        """Basic extraction in both streaming and non-streaming modes."""
        parser = TestThinkingReasoningParser(test_tokenizer)

        model_output = [
            "<test:think>", "Some ", "reasoning ", "content", "</test:think>",
            "Final ", "answer"
        ]

        reasoning, content = run_reasoning_extraction(parser,
                                                      model_output,
                                                      streaming=streaming)

        assert reasoning == "Some reasoning content"
        assert content == "Final answer"

    def test_streaming_with_incremental_deltas(self, test_tokenizer):
        """Streaming with one small delta per step."""
        parser = TestThinkingReasoningParser(test_tokenizer)

        deltas = [
            "<test:think>",
            "Some ",
            "reasoning ",
            "content",
            "</test:think>",
            "Final ",
            "answer",
        ]

        reasoning, content = run_reasoning_extraction(parser,
                                                      deltas,
                                                      streaming=True)

        assert reasoning == "Some reasoning content"
        assert content == "Final answer"

    def test_streaming_with_start_token(self, test_tokenizer):
        """Streaming where the start marker arrives as the first delta."""
        parser = TestThinkingReasoningParser(test_tokenizer)

        deltas = [
            "<test:think>",
            "Some ",
            "reasoning",
            "</test:think>",
            "Answer",
        ]

        reasoning, content = run_reasoning_extraction(parser,
                                                      deltas,
                                                      streaming=True)

        assert reasoning == "Some reasoning"
        assert content == "Answer"

    def test_streaming_no_end_token(self, test_tokenizer):
        """Streaming where the end marker never arrives."""
        parser = TestThinkingReasoningParser(test_tokenizer)

        deltas = [
            "<test:think>",
            "Some ",
            "reasoning ",
            "without ",
            "end",
        ]

        reasoning, content = run_reasoning_extraction(parser,
                                                      deltas,
                                                      streaming=True)

        assert reasoning == "Some reasoning without end"
        assert content is None

    def test_streaming_only_end_token(self, test_tokenizer):
        """Streaming where the end marker appears without a start marker."""
        parser = TestThinkingReasoningParser(test_tokenizer)

        # No start marker here: the stream opens directly with reasoning
        # text, which is the case this test is named for. Feeding the start
        # marker as well would just repeat test_streaming_with_start_token.
        deltas = [
            "Reasoning ",
            "content",
            "</test:think>",
            "Final",
        ]

        reasoning, content = run_reasoning_extraction(parser,
                                                      deltas,
                                                      streaming=True)

        assert reasoning == "Reasoning content"
        assert content == "Final"
class TestBaseThinkingReasoningParserMultipleImplementations:
    """Two subclasses with different markers used side by side."""

    def test_different_token_implementations(self, test_tokenizer):
        """Each parser splits on its own end marker and has distinct ids."""
        think_parser = TestThinkingReasoningParser(test_tokenizer)
        alt_parser = TestThinkingReasoningParserAlt(test_tokenizer)

        # First implementation
        think_output = "Reasoning1</test:think>Content1"
        think_result = run_reasoning_extraction(think_parser, [think_output])
        reasoning1, content1 = think_result
        assert reasoning1 == "Reasoning1"
        assert content1 == "Content1"

        # Second implementation
        alt_output = "Reasoning2<alt:end>Content2"
        alt_result = run_reasoning_extraction(alt_parser, [alt_output])
        reasoning2, content2 = alt_result
        assert reasoning2 == "Reasoning2"
        assert content2 == "Content2"

        # The marker strings and their ids must not collide
        assert think_parser.start_token != alt_parser.start_token
        assert think_parser.end_token != alt_parser.end_token
        assert think_parser.start_token_id != alt_parser.start_token_id
        assert think_parser.end_token_id != alt_parser.end_token_id
class TestBaseThinkingReasoningParserEdgeCases:
    """Unusual marker layouts in the model output."""

    def test_multiple_end_tokens(self, test_tokenizer):
        """With two end markers, the split happens at the first one."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)

        text = "First</test:think>Middle</test:think>Last"
        result = run_reasoning_extraction(thinking_parser, [text])
        reasoning, content = result

        assert reasoning == "First"
        assert content == "Middle</test:think>Last"

    def test_nested_tokens(self, test_tokenizer):
        """A second start marker inside the reasoning is kept as text."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)

        text = "<test:think>Outer<test:think>Inner</test:think>Content"
        result = run_reasoning_extraction(thinking_parser, [text])
        reasoning, content = result

        assert reasoning == "Outer<test:think>Inner"
        assert content == "Content"

    def test_malformed_tokens(self, test_tokenizer):
        """Look-alike markers are not treated as the real ones."""
        thinking_parser = TestThinkingReasoningParser(test_tokenizer)

        text = "<test:thinking>Not a real token</test:thinking>Content"
        result = run_reasoning_extraction(thinking_parser, [text])
        reasoning, content = result

        # Nothing matches exactly, so the whole output is reasoning
        expected = "<test:thinking>Not a real token</test:thinking>Content"
        assert reasoning == expected
        assert content is None

View File

@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, cast
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "seed_oss"
start_token = "<seed:think>"
end_token = "</seed:think>"
# Tokenizer source for these tests; the fixture below adds the SeedOSS markers when the vocabulary lacks them
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def seedoss_tokenizer():
    """Load the reference tokenizer and make sure both SeedOSS markers exist.

    Returns:
        The tokenizer loaded from ``REASONING_MODEL_NAME`` with
        ``start_token`` and ``end_token`` present in its vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    vocab = tokenizer.get_vocab()
    # Check each marker on its own: testing only the start token would skip
    # adding the end token when just one of the two is already present.
    missing = [tok for tok in (start_token, end_token) if tok not in vocab]
    if missing:
        tokenizer.add_tokens(missing)
    return tokenizer
# Test cases. Each dict holds the raw model "output" and the expected
# "reasoning_content" / "content" split. The "is_reasoning_end" entries are
# not read by any test visible in this file.

# End marker in the middle: text before it is reasoning, text after is content.
SIMPLE_REASONING: dict[str, Any] = {
"output": "This is a reasoning section</seed:think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
# End marker is the last thing in the output: no content is expected.
COMPLETE_REASONING: dict[str, Any] = {
"output": "This is a reasoning section</seed:think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
# No end marker: the whole output is expected back as reasoning.
NO_CONTENT: dict[str, Any] = {
"output": "This is content",
"reasoning_content": "This is content",
"content": None,
"is_reasoning_end": False,
}
# NOTE(review): not referenced by any test visible in this file.
NO_REASONING_STREAMING: dict[str, Any] = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
# Newlines on both sides of the end marker.
MULTIPLE_LINES: dict[str, Any] = {
"output": "This\nThat</seed:think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
# Output that begins with the start marker; it is not part of the
# expected reasoning text.
WITH_START_TOKEN: dict[str, Any] = {
"output": ("<seed:think>This is a reasoning section"
"</seed:think>This is the rest"),
"reasoning_content":
"This is a reasoning section",
"content":
"This is the rest",
"is_reasoning_end":
True,
}
# Only the end marker appears, with no start marker before the reasoning.
ONLY_END_TOKEN: dict[str, Any] = {
"output": "Some reasoning</seed:think>This is the rest",
"reasoning_content": "Some reasoning",
"content": "This is the rest",
"is_reasoning_end": True,
}
# Neither marker appears: the whole output is expected back as reasoning.
NO_TOKENS: dict[str, Any] = {
"output": "This is just content without any reasoning tokens",
"reasoning_content": "This is just content without any reasoning tokens",
"content": None,
"is_reasoning_end": False,
}
def test_seedoss_reasoning_parser_creation(seedoss_tokenizer):
    """The registered parser can be built and exposes the expected markers."""
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    assert isinstance(seedoss, ReasoningParser)
    assert seedoss.start_token == start_token
    assert seedoss.end_token == end_token
@pytest.mark.parametrize("streaming", [True, False])
def test_simple_reasoning(seedoss_tokenizer, streaming):
    """Reasoning text, the end marker, then content."""
    case = SIMPLE_REASONING
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    extracted = run_reasoning_extraction(seedoss,
                                         [cast(str, case["output"])],
                                         streaming=streaming)
    got_reasoning, got_content = extracted

    assert got_reasoning == case["reasoning_content"]
    assert got_content == case["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_complete_reasoning(seedoss_tokenizer, streaming):
    """Output that ends right at the end marker, with nothing after it."""
    case = COMPLETE_REASONING
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    extracted = run_reasoning_extraction(seedoss,
                                         [cast(str, case["output"])],
                                         streaming=streaming)
    got_reasoning, got_content = extracted

    assert got_reasoning == case["reasoning_content"]
    assert got_content == case["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_no_content(seedoss_tokenizer, streaming):
    """Without an end marker the whole output counts as reasoning."""
    case = NO_CONTENT
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    extracted = run_reasoning_extraction(seedoss,
                                         [cast(str, case["output"])],
                                         streaming=streaming)
    got_reasoning, got_content = extracted

    assert got_reasoning == case["reasoning_content"]
    assert got_content == case["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_multiple_lines(seedoss_tokenizer, streaming):
    """Newlines in the reasoning and in the content are preserved."""
    case = MULTIPLE_LINES
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    extracted = run_reasoning_extraction(seedoss,
                                         [cast(str, case["output"])],
                                         streaming=streaming)
    got_reasoning, got_content = extracted

    assert got_reasoning == case["reasoning_content"]
    assert got_content == case["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_with_start_token(seedoss_tokenizer, streaming):
    """Output wrapped in both the start and the end marker."""
    case = WITH_START_TOKEN
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    extracted = run_reasoning_extraction(seedoss,
                                         [cast(str, case["output"])],
                                         streaming=streaming)
    got_reasoning, got_content = extracted

    assert got_reasoning == case["reasoning_content"]
    assert got_content == case["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_only_end_token(seedoss_tokenizer, streaming):
    """
    Output that contains the end marker but no start marker
    (SeedOSS typical behavior).
    """
    case = ONLY_END_TOKEN
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    extracted = run_reasoning_extraction(seedoss,
                                         [cast(str, case["output"])],
                                         streaming=streaming)
    got_reasoning, got_content = extracted

    assert got_reasoning == case["reasoning_content"]
    assert got_content == case["content"]
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tokens(seedoss_tokenizer, streaming):
    """Output with neither marker comes back entirely as reasoning."""
    case = NO_TOKENS
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    extracted = run_reasoning_extraction(seedoss,
                                         [cast(str, case["output"])],
                                         streaming=streaming)
    got_reasoning, got_content = extracted

    assert got_reasoning == case["reasoning_content"]
    assert got_content == case["content"]
def test_is_reasoning_end(seedoss_tokenizer):
    """is_reasoning_end is True only when the end token id is in the ids."""
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    closing_id = seedoss.end_token_id

    ids_with_end = [1, 2, closing_id, 4]
    assert seedoss.is_reasoning_end(ids_with_end) is True

    ids_without_end = [1, 2, 3, 4]
    assert seedoss.is_reasoning_end(ids_without_end) is False
def test_extract_content_ids(seedoss_tokenizer):
    """extract_content_ids returns the ids that follow the end token."""
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    closing_id = seedoss.end_token_id

    # End token in the middle
    assert seedoss.extract_content_ids([1, 2, closing_id, 4, 5]) == [4, 5]

    # End token as the last element
    assert seedoss.extract_content_ids([1, 2, 3, closing_id]) == []

    # No end token
    assert seedoss.extract_content_ids([1, 2, 3, 4]) == []
def test_streaming_delta_processing(seedoss_tokenizer):
    """Streaming extraction fed one small delta at a time."""
    seedoss = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    # The stream has no start marker; the end marker arrives as its own delta
    chunks = [
        "Some ",
        "reasoning ",
        "content",
        "</seed:think>",
        "Final ",
        "answer",
    ]

    extracted = run_reasoning_extraction(seedoss, chunks, streaming=True)
    got_reasoning, got_content = extracted

    assert got_reasoning == "Some reasoning content"
    assert got_content == "Final answer"

View File

@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch
import pytest
@ -388,3 +390,108 @@ def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
else:
actual_max_len = model_config.get_and_verify_max_len(max_model_len)
assert actual_max_len == expected_max_len
class MockConfig:
    """Minimal stand-in passed to maybe_pull_model_tokenizer_for_runai.

    Carries only the attributes set here: ``model``, ``tokenizer`` and
    ``model_weights`` (initialised to ``None``).
    """

    def __init__(self, model: str, tokenizer: str):
        self.model, self.tokenizer = model, tokenizer
        self.model_weights = None
@pytest.mark.parametrize("s3_url", [
    "s3://example-bucket-1/model/",
    "s3://example-bucket-2/model/",
])
@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files')
def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url):
    """Test that S3 URLs create deterministic local directories for model and
    tokenizer.

    The method under test is run twice on fresh configs with the same URL.
    Both runs must leave ``model`` and ``tokenizer`` pointing at existing
    local directories, and the second run must give the same paths as the
    first.
    """

    def resolve_local_dirs():
        # Run the method on a fresh config, check both resulting paths are
        # existing directories, and return them as (model, tokenizer).
        config = MockConfig(model=s3_url, tokenizer=s3_url)
        ModelConfig.maybe_pull_model_tokenizer_for_runai(
            config, s3_url, s3_url)
        for label, path in (("Model", config.model),
                            ("Tokenizer", config.tokenizer)):
            assert os.path.exists(
                path), f"{label} directory does not exist: {path}"
            assert os.path.isdir(
                path), f"{label} path is not a directory: {path}"
        return config.model, config.tokenizer

    # Mock pull_files to avoid actually downloading files during tests
    mock_pull_files.return_value = None

    created_model_dir, created_tokenizer_dir = resolve_local_dirs()

    # The S3 URL must have been replaced by a local path
    assert created_model_dir != s3_url, (
        "Model path should be converted to local directory")
    assert created_tokenizer_dir != s3_url, (
        "Tokenizer path should be converted to local directory")

    # A second, independent run with the same URL must give the same paths
    repeated_model_dir, repeated_tokenizer_dir = resolve_local_dirs()

    assert repeated_model_dir == created_model_dir, (
        f"Model paths are not deterministic. "
        f"Original: {created_model_dir}, New: {repeated_model_dir}")
    assert repeated_tokenizer_dir == created_tokenizer_dir, (
        f"Tokenizer paths are not deterministic. "
        f"Original: {created_tokenizer_dir}, New: {repeated_tokenizer_dir}")
@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files')
def test_s3_url_different_models_create_different_directories(mock_pull_files):
    """Two distinct S3 URLs must resolve to two distinct local directories."""
    # pull_files is mocked so that nothing is downloaded
    mock_pull_files.return_value = None

    s3_url1 = "s3://example-bucket-1/model/"
    s3_url2 = "s3://example-bucket-2/model/"

    first_cfg = MockConfig(model=s3_url1, tokenizer=s3_url1)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(first_cfg, s3_url1,
                                                     s3_url1)

    second_cfg = MockConfig(model=s3_url2, tokenizer=s3_url2)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(second_cfg, s3_url2,
                                                     s3_url2)

    # Different URLs, different directories
    assert first_cfg.model != second_cfg.model, (
        f"Different S3 URLs should create different model directories. "
        f"URL1 model: {first_cfg.model}, URL2 model: {second_cfg.model}")
    assert first_cfg.tokenizer != second_cfg.tokenizer, (
        f"Different S3 URLs should create different tokenizer directories. "
        f"URL1 tokenizer: {first_cfg.tokenizer}, "
        f"URL2 tokenizer: {second_cfg.tokenizer}")

    # All four resolved paths must be existing directories
    resolved_paths = (
        first_cfg.model,
        first_cfg.tokenizer,
        second_cfg.model,
        second_cfg.tokenizer,
    )
    for resolved in resolved_paths:
        assert os.path.exists(resolved) and os.path.isdir(resolved)

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.entrypoints.openai.tool_parsers import DeepSeekV31ToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL = "deepseek-ai/DeepSeek-V3.1"
@pytest.fixture(scope="module")
def deepseekv31_tokenizer():
    """Tokenizer for ``MODEL``, obtained through ``get_tokenizer``."""
    loaded = get_tokenizer(tokenizer_name=MODEL)
    return loaded
@pytest.fixture
def parser(deepseekv31_tokenizer):
    """A DeepSeekV31ToolParser built on the module tokenizer fixture."""
    tool_parser = DeepSeekV31ToolParser(deepseekv31_tokenizer)
    return tool_parser
def test_extract_tool_calls_with_tool(parser):
    """Plain text followed by a single tool call block."""
    # NOTE(review): the marker strings below are copied as they appear here;
    # confirm they match the tokenizer's tool-call special tokens exactly.
    segments = [
        "normal text",
        "<tool▁calls▁begin>",
        "<tool▁call▁begin>foo<tool▁sep>{\"x\":1}<tool▁call▁end>",
        "<tool▁calls▁end>",
    ]
    model_output = "".join(segments)

    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called
    assert len(result.tool_calls) == 1
    only_call = result.tool_calls[0]
    assert only_call.function.name == "foo"
    assert only_call.function.arguments == "{\"x\":1}"
    assert result.content == "normal text"
def test_extract_tool_calls_with_multiple_tools(parser):
    """Prefix text, two tool calls in one block, then trailing text."""
    # NOTE(review): the marker strings below are copied as they appear here;
    # confirm they match the tokenizer's tool-call special tokens exactly.
    segments = [
        "some prefix text",
        "<tool▁calls▁begin>",
        "<tool▁call▁begin>foo<tool▁sep>{\"x\":1}<tool▁call▁end>",
        "<tool▁call▁begin>bar<tool▁sep>{\"y\":2}<tool▁call▁end>",
        "<tool▁calls▁end>",
        " some suffix text",
    ]
    model_output = "".join(segments)

    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called
    assert len(result.tool_calls) == 2

    first_call = result.tool_calls[0]
    assert first_call.function.name == "foo"
    assert first_call.function.arguments == "{\"x\":1}"

    second_call = result.tool_calls[1]
    assert second_call.function.name == "bar"
    assert second_call.function.arguments == "{\"y\":2}"

    # Only the text before the tool call block is reported as content
    assert result.content == "some prefix text"

View File

@ -70,7 +70,12 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
assert extracted_info.content == "This is a test"
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
@pytest.mark.parametrize("tool_args", [
'{"location": "Tokyo"}',
'{\n"location": "Tokyo"\n}',
])
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
tool_args):
convo = Conversation.from_messages([
Message.from_role_and_content(Role.USER,
"What is the weather in Tokyo?"),
@ -80,7 +85,7 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
tool_args).with_channel("commentary").with_recipient(
"functions.get_current_weather").with_content_type("json"),
])
token_ids = harmony_encoding.render_conversation_for_completion(
@ -121,6 +126,17 @@ def test_extract_tool_calls_multiple_tools(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
"functions.get_user_location").with_content_type("json"),
Message.from_role_and_content(
Role.ASSISTANT, '{"location": "Tokyo"}').with_channel(
"commentary").with_recipient("functions.no_content_type"),
Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel(
"commentary").with_recipient("functions.not_json_no_content_type"),
Message.from_role_and_content(
Role.ASSISTANT, '{}').with_channel("commentary").with_recipient(
"functions.empty_args").with_content_type("json"),
Message.from_role_and_content(
Role.ASSISTANT, '').with_channel("commentary").with_recipient(
"functions.no_args").with_content_type("json"),
])
token_ids = harmony_encoding.render_conversation_for_completion(
convo,
@ -141,7 +157,63 @@ def test_extract_tool_calls_multiple_tools(
ToolCall(function=FunctionCall(
name="get_user_location",
arguments=json.dumps({"location": "Tokyo"}),
)),
ToolCall(function=FunctionCall(
name="no_content_type",
arguments=json.dumps({"location": "Tokyo"}),
)),
ToolCall(function=FunctionCall(
name="not_json_no_content_type",
arguments="foo",
)),
ToolCall(function=FunctionCall(
name="empty_args",
arguments=json.dumps({}),
)),
ToolCall(function=FunctionCall(
name="no_args",
arguments="",
))
]
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
assert extracted_info.content is None
def test_extract_tool_calls_with_content(
    openai_tool_parser,
    harmony_encoding,
):
    """A tool call followed by a final-channel message yields both."""
    final_content = "This tool call will get the weather."

    user_msg = Message.from_role_and_content(
        Role.USER, "What is the weather in Tokyo based on where I'm at?")
    analysis_msg = Message.from_role_and_content(
        Role.ASSISTANT,
        'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
    ).with_channel("analysis")
    tool_msg = Message.from_role_and_content(
        Role.ASSISTANT, '{"location": "Tokyo"}')
    tool_msg = tool_msg.with_channel("commentary")
    tool_msg = tool_msg.with_recipient("functions.get_current_weather")
    tool_msg = tool_msg.with_content_type("json")
    final_msg = Message.from_role_and_content(
        Role.ASSISTANT, final_content).with_channel("final")

    convo = Conversation.from_messages(
        [user_msg, analysis_msg, tool_msg, final_msg])
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo,
        Role.ASSISTANT,
    )

    # The parser is driven by token ids; the text argument is left empty
    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert extracted_info.tools_called

    weather_call = ToolCall(function=FunctionCall(
        name="get_current_weather",
        arguments=json.dumps({"location": "Tokyo"}),
    ))
    assert_tool_calls(extracted_info.tool_calls, [weather_call])
    assert extracted_info.content == final_content

View File

@ -31,7 +31,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
max_model_len=256,
max_seq_len_to_capture=256,
max_num_seqs=8,
tensor_parallel_size=tp,
enable_lora=True,

View File

@ -9,8 +9,9 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
split_attn_metadata,
split_decodes_and_prefills)
from vllm.v1.worker.ubatch_splitting import create_ubatch_slices
@pytest.fixture
@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
def apply_split_decodes_and_prefills(query_lens: list[int],
                                     decode_threshold: int,
                                     require_uniform: bool):
    """Run ``split_decodes_and_prefills`` on a synthetic CPU batch.

    Request ``i`` (0-based) is given a sequence length of ``10 * (i + 1)``
    and the query length ``query_lens[i]``; the metadata is built with a
    block size of 16. The result of ``split_decodes_and_prefills`` is
    returned unchanged.
    """
    batch_spec = BatchSpec(
        seq_lens=[10 * idx for idx in range(1, len(query_lens) + 1)],
        query_lens=query_lens,
    )
    metadata = create_common_attn_metadata(batch_spec,
                                           block_size=16,
                                           device=torch.device("cpu"))
    return split_decodes_and_prefills(metadata,
                                      decode_threshold=decode_threshold,
                                      require_uniform=require_uniform)
def test_split_decodes_and_prefills_nonuniform_all_ones():
    """Non-uniform mode, threshold 1: three length-1 queries are decodes."""
    lens = [1, 1, 1]
    split = apply_split_decodes_and_prefills(lens, 1, False)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    assert (n_decodes, n_prefills) == (3, 0)
    assert (n_decode_tokens, n_prefill_tokens) == (3, 0)
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
    """Non-uniform mode, threshold 3: every query (length 1-3) is a decode."""
    lens = [1, 2, 1, 3, 2, 1, 2]
    split = apply_split_decodes_and_prefills(lens, 3, False)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    assert (n_decodes, n_prefills) == (7, 0)
    # All tokens in the batch are expected to be counted as decode tokens.
    assert (n_decode_tokens, n_prefill_tokens) == (sum(lens), 0)
def test_split_decodes_and_prefills_nonuniform_all_prefills():
    """Non-uniform mode, threshold 3: queries of length 4-7 are prefills."""
    lens = [4, 5, 6, 7]
    split = apply_split_decodes_and_prefills(lens, 3, False)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    assert (n_decodes, n_prefills) == (0, 4)
    # All tokens in the batch are expected to be counted as prefill tokens.
    assert (n_decode_tokens, n_prefill_tokens) == (0, sum(lens))
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
    """Non-uniform mode, threshold 4: the batch splits at the first query > 4."""
    lens = [2, 1, 3, 4, 5, 6, 7, 8]
    split = apply_split_decodes_and_prefills(lens, 4, False)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    # Decodes: 2, 1, 3, 4 (all <= 4). Prefills: 5, 6, 7, 8 (all > 4).
    assert (n_decodes, n_prefills) == (4, 4)
    # Decode tokens: 2 + 1 + 3 + 4. Prefill tokens: 5 + 6 + 7 + 8.
    assert (n_decode_tokens, n_prefill_tokens) == (10, 26)
def test_split_decodes_and_prefills_uniform_all_ones():
    """Uniform mode, threshold 1: three length-1 queries are decodes."""
    lens = [1, 1, 1]
    split = apply_split_decodes_and_prefills(lens, 1, True)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    assert (n_decodes, n_prefills) == (3, 0)
    assert (n_decode_tokens, n_prefill_tokens) == (3, 0)
def test_split_decodes_and_prefills_uniform_all_short_decodes():
    """Uniform mode, threshold 3: only the leading run of equal lengths
    is expected to be classified as decodes."""
    lens = [2, 2, 1, 3, 2, 1, 2]
    split = apply_split_decodes_and_prefills(lens, 3, True)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    # The two leading 2s are decodes; the remaining five are prefills.
    assert (n_decodes, n_prefills) == (2, 5)
    assert n_decode_tokens == 4
    assert n_prefill_tokens == sum(lens[2:])  # 1 + 3 + 2 + 1 + 2
def test_split_decodes_and_prefills_uniform_all_prefills():
    """Uniform mode, threshold 3: queries of length 4-7 are prefills."""
    lens = [4, 5, 6, 7]
    split = apply_split_decodes_and_prefills(lens, 3, True)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    assert (n_decodes, n_prefills) == (0, 4)
    assert (n_decode_tokens, n_prefill_tokens) == (0, sum(lens))
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
    """Uniform mode, threshold 4: a leading run of equal-length queries
    followed by longer ones."""
    lens = [2, 2, 2, 4, 5, 6, 7, 8]
    split = apply_split_decodes_and_prefills(lens, 4, True)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    # Decodes: 2, 2, 2 (all <= 4 and uniform).
    # Prefills: 4 (within the threshold, but not equal to the leading 2s)
    # and 5, 6, 7, 8 (all > 4).
    assert (n_decodes, n_prefills) == (3, 5)
    # Decode tokens: 2 + 2 + 2. Prefill tokens: 4 + 5 + 6 + 7 + 8.
    assert (n_decode_tokens, n_prefill_tokens) == (6, 30)
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
    """Uniform mode, threshold 4: the second query already differs from the
    first, so only the first one is expected to be a decode."""
    lens = [2, 1, 2, 4, 5, 6, 7, 8]
    split = apply_split_decodes_and_prefills(lens, 4, True)
    n_decodes, n_prefills, n_decode_tokens, n_prefill_tokens = split
    # Decode: only the first 2.
    # Prefills: 1, 2, 4, 5, 6, 7, 8 (either > 4 or non-uniform).
    assert (n_decodes, n_prefills) == (1, 7)
    assert n_decode_tokens == 2
    # Every token other than the first request's two.
    assert n_prefill_tokens == sum(lens) - 2
@pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[

View File

@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
get_block_hash, get_group_id,
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock, get_block_hash,
get_group_id,
get_request_block_hasher,
hash_block_tokens, init_none_hash,
make_block_hash_with_group_id)
@ -138,7 +139,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
# Check full block metadata
parent_block_hash = None
@ -171,7 +172,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -207,7 +208,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([6], )
assert blocks is not None and blocks.get_block_ids() == ([6], )
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@ -227,7 +228,9 @@ def test_prefill(hash_fn):
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
assert blocks is not None and blocks.get_block_ids() == ([
7, 8, 9, 10, 4, 5, 6, 3, 2, 1
], )
assert free_block_queue.num_free_blocks == 0
assert (free_block_queue.fake_free_list_head.next_free_block
@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
8], [9, 10, 11, 12])
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
5, 6, 7, 8
], [9, 10, 11, 12])
# Check full block metadata
parent_block_hash = None
@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([13], [14], [15])
assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
manager.free(req1)
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block)
manager.block_pool.cached_block_hash_to_block._cache)
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
def test_partial_request_hit(request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids,
block_size, sha256)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop(
manager.block_pool.cached_block_hash_to_block._cache.pop(
hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(req.block_hashes) == 3
@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block[
manager.block_pool.cached_block_hash_to_block._cache[
hash_with_group_id] = cached_block_hash_to_block_bak[
hash_with_group_id]
manager.free(req)
@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
# total cache miss.
# The cache hit length of full attention is 1 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
# Then it is cache miss as the two type of layers
# have different hit length.
test_partial_request_hit("8", [
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
@ -406,7 +412,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
# Check full block metadata
@ -441,7 +447,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -478,6 +484,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks is not None
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
@ -513,7 +520,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@ -558,7 +565,8 @@ def test_evict():
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
# 5 full + 1 partial
assert blocks is not None and len(blocks.blocks[0]) == 6
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
@ -570,7 +578,7 @@ def test_evict():
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 3 # 3 full blocks
assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks
last_token_id += 3 * 16
# 10 - (6 + 3) == 1
@ -592,7 +600,7 @@ def test_evict():
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([10], )
assert blocks is not None and blocks.get_block_ids() == ([10], )
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
# Deallocate the block.
manager.free(req)
@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert manager.block_pool.blocks[blocks.blocks[0]
[0].block_id].block_hash is None
@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 1
# Allocate another block.
@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
# Free the blocks.
@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 3
assert blocks is not None and len(blocks.blocks[0]) == 3
# Free the blocks.
manager.free(req1)
@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 4
assert blocks is not None and len(blocks.blocks[0]) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)), block_size, sha256)
@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
assert len(block_pool.cached_block_hash_to_block) == 2
assert all([block.block_hash is not None for block in blocks])
# Test that blocks that don't start from the beginning are cached correctly.
# Test that blocks that don't start from the beginning are cached
# correctly.
blocks += [KVCacheBlock(block_id=2)]
block_pool.cache_full_blocks(
request=req,
@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, block_size, sha256)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
# Manually add all blocks to cached_blocks
for block, block_hash in zip(pool.blocks, block_hashes):
block.block_hash = block_hash
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
pool.cached_block_hash_to_block.insert(block_hash, block)
block0, block1, block2, block3 = pool.blocks
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block0.block_id: block0,
block3.block_id: block3
block3.block_id: block3,
},
block_hash1: {
block1.block_id: block1
},
block_hash2: {
block2.block_id: block2
}
block_hash1: block1,
block_hash2: block2,
}
# Evict block1
pool._maybe_evict_cached_block(block1)
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block0.block_id: block0,
block3.block_id: block3
},
block_hash2: {
block2.block_id: block2
}
block_hash2: block2,
}
# Evict block0: block_hash0 entry should NOT be removed, as block3
# also use the same hash
pool._maybe_evict_cached_block(block0)
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block3.block_id: block3
},
block_hash2: {
block2.block_id: block2
}
block_hash2: block2,
}
# Evict block2
pool._maybe_evict_cached_block(block2)
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
# Evict block3
pool._maybe_evict_cached_block(block3)
assert pool.cached_block_hash_to_block == {}
assert pool.cached_block_hash_to_block._cache == {}
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
# Evict the first block in the request
assert manager.block_pool.get_cached_block(
block_hash_first_block, kv_cache_group_ids=[0]) is not None
manager.block_pool.cached_block_hash_to_block.pop(
manager.block_pool.cached_block_hash_to_block._cache.pop(
make_block_hash_with_group_id(block_hash_first_block, 0))
# New request
@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
# there will be no matched prefix.
assert len(computed_blocks.blocks[0]) == 0
assert num_tokens == 0
def test_block_lookup_cache_single_block_per_key():
    """Exercise insert / get_one_block / pop with one block per key.

    Three keys are used; only the first two ever receive a block, so the
    third must always look up as ``None``.
    """
    lookup = BlockHashToBlockMap()
    first_key = BlockHashWithGroupId(b"hash0")
    second_key = BlockHashWithGroupId(b"hash1")
    unused_key = BlockHashWithGroupId(b"hash2")
    first_block = KVCacheBlock(0)
    second_block = KVCacheBlock(1)

    def check_lookup(expected_first, expected_second):
        # The third key never has a block inserted.
        assert lookup.get_one_block(first_key) is expected_first
        assert lookup.get_one_block(second_key) is expected_second
        assert lookup.get_one_block(unused_key) is None

    # Empty map: nothing is found.
    check_lookup(None, None)

    # First key inserted.
    lookup.insert(first_key, first_block)
    check_lookup(first_block, None)

    # Second key inserted.
    lookup.insert(second_key, second_block)
    check_lookup(first_block, second_block)

    # No block popped due to block_id mismatch.
    assert lookup.pop(first_key, 100) is None
    check_lookup(first_block, second_block)

    # Block popped with (first key, block ID 0).
    assert lookup.pop(first_key, 0) is first_block
    check_lookup(None, second_block)

    # No block popped: the first key no longer has an entry.
    assert lookup.pop(first_key, 1) is None
    check_lookup(None, second_block)

    # Block popped with (second key, block ID 1).
    assert lookup.pop(second_key, 1) is second_block
    check_lookup(None, None)
def test_block_lookup_cache_multi_blocks_per_key():
    """Exercise the map when two blocks are inserted under the same key.

    For each key the test expects ``get_one_block`` to return the block
    inserted first, then the second one after the first is popped, and
    ``None`` once both are gone.
    """
    lookup = BlockHashToBlockMap()
    key_a = BlockHashWithGroupId(b"hash0")
    key_b = BlockHashWithGroupId(b"hash1")
    block_a0 = KVCacheBlock(0)
    block_a1 = KVCacheBlock(1)
    block_b0 = KVCacheBlock(10)
    block_b1 = KVCacheBlock(11)

    for key in (key_a, key_b):
        assert lookup.get_one_block(key) is None

    lookup.insert(key_a, block_a0)
    lookup.insert(key_a, block_a1)
    lookup.insert(key_b, block_b0)
    lookup.insert(key_b, block_b1)

    # (key, first block, its id, second block, its id, an id never inserted)
    scenarios = (
        (key_a, block_a0, 0, block_a1, 1, 2),
        (key_b, block_b0, 10, block_b1, 11, 12),
    )
    for key, first, first_id, second, second_id, absent_id in scenarios:
        assert lookup.get_one_block(key) is first
        assert lookup.pop(key, first_id) is first
        assert lookup.get_one_block(key) is second
        assert lookup.pop(key, second_id) is second
        assert lookup.get_one_block(key) is None
        assert lookup.pop(key, absent_id) is None

View File

@ -47,16 +47,15 @@ def test_chunked_local_attention_possible_cached_prefix():
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
block_pool.cached_block_hash_to_block._cache.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
block_pool.cached_block_hash_to_block.insert(
make_block_hash_with_group_id(block_hash, 0),
block_pool.blocks[i + 10])
computed_blocks = manager.find_longest_cache_hit(
block_hashes=block_hash_list,
@ -112,16 +111,15 @@ def test_sliding_window_possible_cached_prefix():
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
block_pool.cached_block_hash_to_block._cache.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
block_pool.cached_block_hash_to_block.insert(
make_block_hash_with_group_id(block_hash, 0),
block_pool.blocks[i + 10])
computed_blocks = manager.find_longest_cache_hit(
block_hashes=block_hash_list,

View File

@ -81,16 +81,6 @@ class CarDescription(BaseModel):
car_type: CarType
def _load_json(s: str, backend: str) -> str:
if backend != "xgrammar":
return json.loads(s)
# xgrammar specific workarounds
# https://github.com/mlc-ai/xgrammar/issues/286
s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s)
return json.loads(s)
def test_guided_decoding_deprecated():
with pytest.warns(DeprecationWarning,
match="GuidedDecodingParams is deprecated.*"):
@ -177,7 +167,12 @@ def test_structured_output(
if backend != 'lm-format-enforcer':
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
try:
output_json = json.loads(generated_text)
except json.JSONDecodeError as e:
pytest.fail(
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
f"Schema: {sample_json_schema}\nError: {e}")
jsonschema.validate(instance=output_json, schema=sample_json_schema)
#
@ -425,7 +420,12 @@ def test_structured_output(
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
try:
output_json = json.loads(generated_text)
except json.JSONDecodeError as e:
pytest.fail(
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
f"Schema: {json_schema}\nError: {e}")
jsonschema.validate(instance=output_json, schema=json_schema)
#
@ -468,7 +468,12 @@ def test_structured_output(
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
try:
output_json = json.loads(generated_text)
except json.JSONDecodeError as e:
pytest.fail(
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
f"Schema: {json_schema}\nError: {e}")
jsonschema.validate(instance=output_json, schema=json_schema)
if backend not in ["outlines", "lm-format-enforcer"]:

View File

@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
"or, being injured, not kill, except in",
"without the heart, one can only see wrongly.",
"but in rising every time we fall. - Nelson"
"or kill a human being",
]
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

View File

@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing as mp
import os
import tempfile
from multiprocessing import Queue
from typing import Optional
from unittest.mock import patch
import pytest
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import MemorySnapshot
from vllm.v1.worker.gpu_worker import (Worker,
init_worker_distributed_environment)
# Global queue to track operation order across processes
_QUEUE: Optional[Queue] = None
def track_operation(operation: str, rank: int):
    """Record ``(operation, rank)`` on the module-level queue.

    Does nothing when ``_QUEUE`` has not been set.
    """
    if _QUEUE is None:
        return
    _QUEUE.put((operation, rank))
def make_operation_tracker(operation_name: str, original_func):
    """Wrap ``original_func`` so that each call is recorded before it runs.

    Args:
        operation_name: Label passed to ``track_operation`` on every call.
        original_func: Callable that the returned wrapper delegates to.

    Returns:
        A function that records ``operation_name`` together with the rank
        read from the ``RANK`` environment variable (``-1`` when unset),
        then calls ``original_func`` with the same arguments and returns
        its result.
    """

    def tracked(*args, **kwargs):
        # The rank is read at call time, not when the wrapper is created.
        track_operation(operation_name, int(os.getenv("RANK", "-1")))
        return original_func(*args, **kwargs)

    return tracked
def worker_process(rank: int, world_size: int, distributed_init_method: str,
                   queue: Queue, error_queue: Queue):
    """Worker process that initializes a GPU worker with proper tracking.

    Args:
        rank: Rank of this process; also used as the local rank.
        world_size: Total number of worker processes. Exported as
            ``WORLD_SIZE`` and used as the tensor-parallel size.
        distributed_init_method: Init method string passed to the worker.
        queue: Queue receiving ``(operation, rank)`` records and, at the
            end, a ``("success", rank)`` marker.
        error_queue: Queue receiving ``(rank, message, exception type name)``
            if anything in this function raises.

    Raises:
        Exception: Any exception raised during setup is reported on
            ``error_queue`` and then re-raised.
    """
    # track_operation() reads this module-level global.
    global _QUEUE
    _QUEUE = queue

    try:
        # Set environment variables
        os.environ["RANK"] = str(rank)
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        # Create vLLM config with small model. The tensor-parallel size
        # follows world_size so it cannot drift from the number of
        # processes the caller spawns.
        vllm_config = EngineArgs(model="facebook/opt-125m",
                                 tensor_parallel_size=world_size,
                                 load_format="dummy").create_engine_config()

        # Create worker
        worker = Worker(
            vllm_config=vllm_config,
            local_rank=rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

        # Get original functions before patching
        original_init_worker = init_worker_distributed_environment
        original_memory_snapshot_init = MemorySnapshot.__init__
        original_all_reduce = torch.distributed.all_reduce

        # Apply minimal patches to track operation order. Each patched
        # callable records its label and then delegates to the original.
        init_patch = patch(
            'vllm.v1.worker.gpu_worker.init_worker_distributed_environment',
            side_effect=make_operation_tracker("init_distributed",
                                               original_init_worker))
        memory_patch = patch.object(
            MemorySnapshot, '__init__',
            make_operation_tracker("memory_snapshot",
                                   original_memory_snapshot_init))
        all_reduce_patch = patch('torch.distributed.all_reduce',
                                 side_effect=make_operation_tracker(
                                     "nccl_all_reduce", original_all_reduce))

        with init_patch, memory_patch, all_reduce_patch:
            # Initialize device (this is where we test the order)
            worker.init_device()

            # Load model to ensure everything works
            worker.load_model()

        # Signal success
        queue.put(("success", rank))

    except Exception as e:
        # Report the failure to the parent before re-raising.
        error_queue.put((rank, str(e), type(e).__name__))
        raise
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs for tensor parallelism")
def test_init_distributed_is_called_before_memory_snapshot():
    """Test that distributed env is setup before memory snapshot.

    This test makes sure during worker initialization, the initial memory
    snapshot is taken after distributed env is setup to include all the buffers
    allocated by distributed env.

    One worker process is spawned per rank. Each reports the operations it
    performs on a shared queue; the per-rank order is then checked.
    """
    world_size = 2

    # Create a temporary file for distributed init. It is kept on disk
    # (delete=False) so the spawned workers can use it.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        init_file = f.name
    distributed_init_method = f"file://{init_file}"

    processes = []
    try:
        # Create queues for inter-process communication
        ctx = mp.get_context("spawn")
        operation_queue = ctx.Queue()
        error_queue = ctx.Queue()

        # Start worker processes
        for rank in range(world_size):
            p = ctx.Process(target=worker_process,
                            args=(rank, world_size, distributed_init_method,
                                  operation_queue, error_queue))
            p.start()
            processes.append(p)

        # Wait for all processes to complete
        for p in processes:
            p.join(timeout=60)  # 60 second timeout

        # A process that is still alive after join() hit the timeout.
        timed_out = [
            rank for rank, p in enumerate(processes) if p.is_alive()
        ]

        # Check for errors
        errors = []
        while not error_queue.empty():
            rank, error_msg, error_type = error_queue.get()
            errors.append(f"Rank {rank}: {error_type}: {error_msg}")
        errors.extend(f"Rank {rank}: did not finish within the join timeout"
                      for rank in timed_out)

        if errors:
            pytest.fail("Worker processes failed:\n" + "\n".join(errors))

        # Collect all operations from the queue
        operations = []
        while not operation_queue.empty():
            operations.append(operation_queue.get())

        # Verify we got operations from both ranks
        print(f"Collected operations: {operations}")

        # Check operations for each rank
        for rank in range(world_size):
            rank_ops = [op for op, r in operations if r == rank]
            print(f"\nRank {rank} operations: {rank_ops}")

            # Raises ValueError if the operation is not found
            init_distributed = rank_ops.index("init_distributed")
            nccl_all_reduce = rank_ops.index("nccl_all_reduce")
            memory_snapshot = rank_ops.index("memory_snapshot")

            # Verify order: init_distributed should happen before
            # memory_snapshot
            assert init_distributed < nccl_all_reduce < memory_snapshot, (
                f"Rank {rank}: init_distributed (index {init_distributed}) "
                f"must happen before nccl_all_reduce (index {nccl_all_reduce}) "
                f"and memory_snapshot (index {memory_snapshot})")
    finally:
        # Clean up even when the test fails: stop workers that are still
        # running and remove the rendezvous file.
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()
        try:
            os.unlink(init_file)
        except FileNotFoundError:
            # NOTE(review): the file-based store may already have removed
            # the file; a missing file must not mask the real test result.
            pass

View File

@ -1,314 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from itertools import accumulate
from typing import List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.utils import async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that
# lack attention.
class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        """Return the backend's name."""
        return "NO_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the attention implementation class."""
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the metadata class."""
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed five-dimensional shape of ones.

        None of the arguments are used.
        """
        placeholder_shape = (1, 1, 1, 1, 1)
        return placeholder_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode requests batched together.

    The first ``num_prefills`` entries of the per-sequence fields are treated
    as prefill requests and the remaining entries as decode requests.
    ``prefill_metadata`` and ``decode_metadata`` each build, cache and return
    a view restricted to one of the two parts; ``slot_mapping`` and
    ``block_tables`` of those views are empty placeholder tensors.
    """

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # The same sequence lengths, stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length among the prefill batch. 0 if there are
    # decoding requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among the decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,). Context lengths, i.e. the number of tokens computed so
    # far for each sequence.
    context_lens_tensor: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Max number of query tokens among the requests in the batch.
    max_decode_query_len: Optional[int]

    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. E.g., subquery lengths [4, 6] give
    # [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g., sequence lengths [4, 6] give
    # [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Placeholder.
    block_tables: Optional[torch.Tensor] = None

    # Lazily populated results of the two properties below.
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the prefill-only view, or None when there are no prefills.

        The view is built on first access and reused afterwards.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        prefill_count = self.num_prefills

        def _take(values, count):
            # Optional fields stay None; otherwise keep the leading entries.
            return None if values is None else values[:count]

        view = PlaceholderAttentionMetadata(
            num_prefills=prefill_count,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            # Placeholder.
            slot_mapping=torch.empty(0),
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=_take(self.seq_lens, prefill_count),
            seq_lens_tensor=_take(self.seq_lens_tensor, prefill_count),
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            # The cumulative tensors carry one more entry than sequences.
            query_start_loc=_take(self.query_start_loc, prefill_count + 1),
            seq_start_loc=_take(self.seq_start_loc, prefill_count + 1),
            context_lens_tensor=_take(self.context_lens_tensor, prefill_count),
            # Placeholder.
            block_tables=torch.empty(0),
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = view
        return view

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the decode-only view, or None when no tokens are decoded.

        The view is built on first access and reused afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.seq_lens_tensor is not None

        decode_start = self.num_prefills

        if self.seq_lens_tensor is None:
            decode_seq_lens_tensor = None
        else:
            decode_seq_lens_tensor = self.seq_lens_tensor[decode_start:]

        if self.query_start_loc is None:
            decode_query_start_loc = None
        else:
            # Shift so the offsets of the decode part start again at zero.
            decode_query_start_loc = (self.query_start_loc[decode_start:] -
                                      self.query_start_loc[decode_start])

        if self.seq_start_loc is None:
            decode_seq_start_loc = None
        else:
            # Sliced only; unlike query_start_loc this is not re-based.
            decode_seq_start_loc = self.seq_start_loc[decode_start:]

        view = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            # Placeholder.
            slot_mapping=torch.empty(0),
            # NOTE(review): hard-coded to True here, whereas the prefill view
            # forwards self.enable_kv_scales_calculation -- confirm intended.
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=decode_seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=decode_query_start_loc,
            seq_start_loc=decode_seq_start_loc,
            context_lens_tensor=None,
            # Placeholder.
            block_tables=torch.empty(0),
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = view
        return view
class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Accumulate per-sequence lengths and build the placeholder metadata."""

    def __init__(self, input_builder):
        """Keep a handle on the input builder and on its runner."""
        self.input_builder = input_builder
        self.runner = input_builder.runner

    def prepare(self):
        """Reset every accumulator filled in by ``_add_seq_group``."""
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
        """Fold one sequence group into the accumulators.

        The context length of every sequence is always recorded. Prompt
        groups additionally bump the prefill counters; other groups bump the
        decode counters. ``chunked_prefill_enabled`` is not used here.
        """
        prompt_phase = inter_data.is_prompt
        # All seven iterables stay in the zip so the number of iterations is
        # still bounded by the shortest of them, even though the sequence id
        # and the sliding-window block are not read.
        per_sequence = zip(
            inter_data.seq_ids,
            [len(tokens) for tokens in inter_data.input_tokens],
            inter_data.orig_seq_lens,
            inter_data.seq_lens,
            inter_data.query_lens,
            inter_data.context_lens,
            inter_data.curr_sliding_window_blocks,
        )
        for (_, token_len, seq_len, curr_seq_len, query_len, context_len,
             _) in per_sequence:
            self.context_lens.append(context_len)

            if not prompt_phase:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)
                continue

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        builder = self.input_builder
        # Some input builders such as ModelInputForCPUBuilder do not have the
        # "inter_data_list" attribute, so only walk it when it is present.
        if hasattr(builder, "inter_data_list"):
            for inter_data in builder.inter_data_list:
                self._add_seq_group(inter_data,
                                    builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        # Entries past the prefills are the decode queries; 1 when none exist.
        max_decode_query_len = max(query_lens[self.num_prefills:], default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)

        if use_captured_graph:
            # With a captured graph the decode token count is derived from
            # the (maybe padded) batch size instead of the running total.
            num_decode_tokens = batch_size - self.num_prefill_tokens
        else:
            num_decode_tokens = self.num_decode_tokens

        # Running totals with a leading zero, e.g. [4, 6] -> [0, 4, 10].
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
        assert device is not None

        def _to_device(values, dtype):
            return async_tensor_h2d(values, dtype, device,
                                    self.runner.pin_memory)

        context_lens_tensor = _to_device(self.context_lens, torch.int)
        seq_lens_tensor = _to_device(seq_lens, torch.int)
        query_start_loc_tensor = _to_device(query_start_loc, torch.int32)
        seq_start_loc_tensor = _to_device(seq_start_loc, torch.int32)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            # Placeholder.
            slot_mapping=torch.empty(0),
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            # Placeholder.
            block_tables=torch.empty(0),
            use_cuda_graph=use_captured_graph,
        )
class PlaceholderAttentionImpl(AttentionImpl):
    """Attention implementation stub that cannot run a forward pass."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""
        return None

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError()

View File

@ -304,7 +304,7 @@ class CommonAttentionState(AttentionState):
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
max_decode_seq_len=self.runner.max_model_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
@ -390,7 +390,7 @@ class CommonAttentionState(AttentionState):
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda()
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
attn_metadata.max_encoder_seq_len = self.runner.max_model_len
attn_metadata.num_encoder_tokens = 0
def _add_additional_input_buffers_for_enc_dec_model(

View File

@ -115,12 +115,10 @@ class Attention(nn.Module, AttentionLayerBase):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
is_attention_free = cache_config.is_attention_free
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
@ -185,7 +183,6 @@ class Attention(nn.Module, AttentionLayerBase):
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla,
has_sink=self.has_sink)
else:
@ -578,9 +575,7 @@ def unified_attention_fake(
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,
)
@ -631,6 +626,5 @@ direct_register_custom_op(
op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,
)

View File

@ -2,10 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@triton.jit

View File

@ -142,7 +142,6 @@ def get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool = False,
use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
@ -156,7 +155,6 @@ def get_attn_backend(
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
has_sink=has_sink,
@ -169,17 +167,10 @@ def _cached_get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
use_v1: bool = False,
use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend
# Check whether a particular choice of backend was
# previously forced.

View File

@ -547,7 +547,6 @@ if flashinfer_comm is not None:
"scale_out",
],
fake_impl=call_trtllm_fused_allreduce_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
flashinfer_trtllm_fused_allreduce_norm = (
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)

View File

@ -551,8 +551,9 @@ def set_inductor_config(config, runtime_shape):
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
config["max_autotune"] = True
config["coordinate_descent_tuning"] = True
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
config["coordinate_descent_tuning"] = (
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
class EagerAdaptor(CompilerInterface):

View File

@ -8,6 +8,7 @@ from unittest.mock import patch
import torch
import torch.nn as nn
from packaging import version
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
from vllm.compilation.counter import compilation_counter
@ -300,13 +301,13 @@ def _support_torch_compile(
logger.debug(
"enable_cpp_symbolic_shape_guards config not available")
with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call), torch._dynamo.config.patch(
**dynamo_config_patches
), maybe_use_cudagraph_partition_wrapper(
self.vllm_config):
with patch.object(
InliningInstructionTranslator, "inline_call",
patched_inline_call), torch._dynamo.config.patch(
**dynamo_config_patches
), maybe_use_cudagraph_partition_wrapper(
self.vllm_config), _torch27_patch_tensor_subclasses():
output = self.compiled_callable(*args, **kwargs)
return output
# usually, capturing the model once is enough, and then we can
@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and compilation_config.use_inductor_graph_partition):
torch._inductor.utils.set_customized_partition_wrappers(None)
@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
    """
    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
    when using torch 2.7.x. This enables using weight_loader_v2 and the use of
    `BasevLLMParameters` without having to replace them with regular tensors
    before `torch.compile`-time.

    On torch 2.7.x the body of the ``with`` block runs with two patches
    active: the vLLM parameter classes are registered in dynamo's
    ``traceable_tensor_subclasses`` config, and
    ``torch._dynamo.variables.torch.can_dispatch_torch_function`` is replaced
    by a function that always returns False. On any other torch version this
    context manager does nothing.
    """
    from vllm.model_executor.parameter import (BasevLLMParameter,
                                               ModelWeightParameter,
                                               RowvLLMParameter,
                                               _ColumnvLLMParameter)

    def return_false(*args, **kwargs):
        return False

    torch_version = version.parse(torch.__version__)
    is_torch_27 = (version.parse("2.7") <= torch_version <
                   version.parse("2.8"))

    # The guard used to be inverted: it skipped the patches on torch 2.7.x
    # and applied them on every other version, contradicting both the name
    # and the documented purpose of this helper.
    if not is_torch_27:
        yield
        return

    with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
            BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
            RowvLLMParameter
    ]),
          patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
                return_false)):
        yield

View File

@ -27,6 +27,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
CUDAGraphMode, PassConfig)
from vllm.config.device import Device, DeviceConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
@ -38,11 +39,13 @@ from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
try_match_architecture_defaults)
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
MultiModalConfig)
from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
from vllm.config.structured_outputs import StructuredOutputsConfig
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
from vllm.logger import init_logger
@ -81,158 +84,6 @@ class SupportsMetricsInfo(Protocol):
...
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
    """Configuration for the device to use for vLLM execution."""

    device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
    """Device type for vLLM execution.
    This parameter is deprecated and will be
    removed in a future release.
    It will now be set automatically based
    on the current platform."""
    device_type: str = field(init=False)
    """Device type from the current platform. This is set in
    `__post_init__`."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return a hash identifying the settings that affect the structure of
        the computation graph from input ids/embeddings to the final hidden
        states, excluding anything before input ids/embeddings and after the
        final hidden states.
        """
        # No factors to consider: the device/platform information will be
        # summarized by torch/vllm automatically.
        factors: list[Any] = []
        digest = hashlib.md5(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()

    def __post_init__(self):
        """Resolve ``device_type`` and normalise ``device``.

        NOTE(review): a ``device`` that is neither "auto", a ``str`` nor a
        ``torch.device`` (e.g. ``None``) leaves ``device_type`` unassigned,
        so the lookup further down raises ``AttributeError``.
        """
        if self.device == "auto":
            # Automated device type detection.
            from vllm.platforms import current_platform
            self.device_type = current_platform.device_type
            if not self.device_type:
                raise RuntimeError(
                    "Failed to infer device type, please set "
                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
                    "to turn on verbose logging to help debug the issue.")
        elif isinstance(self.device, str):
            # Device type is assigned explicitly, by name.
            self.device_type = self.device
        elif isinstance(self.device, torch.device):
            # Device type is assigned explicitly, through a torch.device.
            self.device_type = self.device.type

        needs_cpu_inputs = self.device_type in ["tpu"]
        if needs_cpu_inputs:
            # Some device types require processing inputs on CPU.
            self.device = None
            return
        # Set device with device type.
        self.device = torch.device(self.device_type)
DetailedTraceModules = Literal["model", "worker", "all"]
@config
@dataclass
class ObservabilityConfig:
    """Configuration for observability - metrics and tracing."""

    show_hidden_metrics_for_version: Optional[str] = None
    """Enable deprecated Prometheus metrics that have been hidden since the
    specified version. For example, if a previously deprecated metric has been
    hidden since the v0.7.0 release, you use
    `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
    you migrate to new metrics. The metric is likely to be removed completely
    in an upcoming release."""

    otlp_traces_endpoint: Optional[str] = None
    """Target URL to which OpenTelemetry traces will be sent."""

    collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
    """It makes sense to set this only if `--otlp-traces-endpoint` is set. If
    set, it will collect detailed traces for the specified modules. This
    involves use of possibly costly and or blocking operations and hence might
    have a performance impact.
    Note that collecting detailed timing information for each request can be
    expensive."""

    @cached_property
    def show_hidden_metrics(self) -> bool:
        """Check if the hidden metrics should be shown."""
        hidden_since = self.show_hidden_metrics_for_version
        if hidden_since is None:
            return False
        return version._prev_minor_version_was(hidden_since)

    @cached_property
    def collect_model_forward_time(self) -> bool:
        """Whether to collect model forward time for the request."""
        modules = self.collect_detailed_traces
        if modules is None:
            return False
        return "model" in modules or "all" in modules

    @cached_property
    def collect_model_execute_time(self) -> bool:
        """Whether to collect model execute time for the request."""
        modules = self.collect_detailed_traces
        if modules is None:
            return False
        return "worker" in modules or "all" in modules

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return a hash identifying the settings that affect the structure of
        the computation graph from input ids/embeddings to the final hidden
        states, excluding anything before input ids/embeddings and after the
        final hidden states.
        """
        # No factors to consider: this config will not affect the
        # computation graph.
        factors: list[Any] = []
        digest = hashlib.md5(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()

    def __post_init__(self):
        """Split a comma-joined trace list and validate the OTLP endpoint.

        Raises:
            ValueError: if ``otlp_traces_endpoint`` is set while
                OpenTelemetry is not available.
        """
        modules = self.collect_detailed_traces
        # A single entry containing commas (e.g. "model,worker") is treated
        # as a joined list and split into its parts.
        if modules is not None and len(modules) == 1 and "," in modules[0]:
            self._parse_collect_detailed_traces()

        from vllm.tracing import is_otel_available, otel_import_error_traceback
        otel_missing = not is_otel_available()
        if otel_missing and self.otlp_traces_endpoint is not None:
            raise ValueError(
                "OpenTelemetry is not available. Unable to configure "
                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                f"installed. Original error:\n{otel_import_error_traceback}")

    def _parse_collect_detailed_traces(self):
        """Replace the trace list by its first entry split on commas."""
        assert isinstance(self.collect_detailed_traces, list)
        joined = self.collect_detailed_traces[0]
        self.collect_detailed_traces = cast(list[DetailedTraceModules],
                                            joined.split(","))
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@ -1009,37 +860,6 @@ def get_layers_from_vllm_config(
}
@config
@dataclass
class SpeechToTextConfig:
    """Configuration for speech-to-text models."""

    sample_rate: float = 16_000
    """Sample rate (Hz) to resample input audio to. Most speech models expect
    16kHz audio input. The input audio will be automatically resampled to this
    rate before processing."""

    max_audio_clip_s: int = 30
    """Maximum duration in seconds for a single audio clip without chunking.
    Audio longer than this will be split into smaller chunks if
    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""

    overlap_chunk_second: int = 1
    """Overlap duration in seconds between consecutive audio chunks when
    splitting long audio. This helps maintain context across chunk boundaries
    and improves transcription quality at split points."""

    min_energy_split_window_size: Optional[int] = 1600
    """Window size in samples for finding low-energy (quiet) regions to split
    audio chunks. The algorithm looks for the quietest moment within this
    window to minimize cutting through speech. Default 1600 samples 100ms
    at 16kHz. If None, no chunking will be done."""

    @property
    def allow_audio_chunking(self) -> bool:
        """Whether chunking is allowed, i.e. a split window size is set."""
        split_window = self.min_energy_split_window_size
        return split_window is not None
def update_config(config: DataclassInstanceT,
overrides: dict[str, Any]) -> DataclassInstanceT:
processed_overrides = {}

74
vllm/config/device.py Normal file
View File

@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field
from typing import Any, Literal, Optional, Union
import torch
from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
    """Configuration for the device to use for vLLM execution."""

    device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
    """Device type for vLLM execution.
    This parameter is deprecated and will be
    removed in a future release.
    It will now be set automatically based
    on the current platform."""
    device_type: str = field(init=False)
    """Device type from the current platform. This is set in
    `__post_init__`."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return a hash identifying the settings that affect the structure of
        the computation graph from input ids/embeddings to the final hidden
        states, excluding anything before input ids/embeddings and after the
        final hidden states.
        """
        # No factors to consider: the device/platform information will be
        # summarized by torch/vllm automatically.
        factors: list[Any] = []
        digest = hashlib.md5(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()

    def __post_init__(self):
        """Resolve ``device_type`` and normalise ``device``.

        NOTE(review): a ``device`` that is neither "auto", a ``str`` nor a
        ``torch.device`` (e.g. ``None``) leaves ``device_type`` unassigned,
        so the lookup further down raises ``AttributeError``.
        """
        if self.device == "auto":
            # Automated device type detection.
            from vllm.platforms import current_platform
            self.device_type = current_platform.device_type
            if not self.device_type:
                raise RuntimeError(
                    "Failed to infer device type, please set "
                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
                    "to turn on verbose logging to help debug the issue.")
        elif isinstance(self.device, str):
            # Device type is assigned explicitly, by name.
            self.device_type = self.device
        elif isinstance(self.device, torch.device):
            # Device type is assigned explicitly, through a torch.device.
            self.device_type = self.device.type

        needs_cpu_inputs = self.device_type in ["tpu"]
        if needs_cpu_inputs:
            # Some device types require processing inputs on CPU.
            self.device = None
            return
        # Set device with device type.
        self.device = torch.device(self.device_type)

View File

@ -177,11 +177,6 @@ class ModelConfig:
graph and always execute the model in eager mode. If False, we will use
CUDA graph and eager execution in hybrid for maximal performance and
flexibility."""
max_seq_len_to_capture: int = 8192
"""Maximum sequence len covered by CUDA graphs. When a sequence has context
length larger than this, we fall back to eager mode. Additionally for
encoder-decoder models, if the sequence length of the encoder input is
larger than this, we fall back to the eager mode."""
max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes the default for the
@ -699,11 +694,12 @@ class ModelConfig:
model: Model name or path
tokenizer: Tokenizer name or path
"""
if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
return
if is_runai_obj_uri(model):
object_storage_model = ObjectStorageModel()
object_storage_model = ObjectStorageModel(url=model)
object_storage_model.pull_files(
model, allow_pattern=["*.model", "*.py", "*.json"])
self.model_weights = model
@ -722,7 +718,7 @@ class ModelConfig:
# Only download tokenizer if needed and not already handled
if is_runai_obj_uri(tokenizer):
object_storage_tokenizer = ObjectStorageModel()
object_storage_tokenizer = ObjectStorageModel(url=tokenizer)
object_storage_tokenizer.pull_files(model,
ignore_pattern=[
"*.pt", "*.safetensors",
@ -1023,21 +1019,8 @@ class ModelConfig:
current_platform.verify_quantization(self.quantization)
def _verify_cuda_graph(self) -> None:
# The `max_seq_len_to_capture` was incorrectly
# based on the encoder's input length (448)
# but not the decoder's larger input length (1500).
# This change ensures the CUDA Graph captures the correct,
# larger sequence length, allowing it to work as intended.
effective_max_seq_len = self.max_model_len
if self.is_encoder_decoder:
effective_max_seq_len = max(
effective_max_seq_len,
getattr(self.hf_config, "max_source_positions", 0))
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
effective_max_seq_len)
# CUDAGraph capture not supported for encoder-decoder models on ROCm
unsupported_rocm = self.is_encoder_decoder
if (unsupported_rocm and not self.enforce_eager
and current_platform.is_rocm()):
logger.warning(

View File

@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from functools import cached_property
from typing import Any, Literal, Optional, cast
from pydantic.dataclasses import dataclass
from vllm import version
from vllm.config.utils import config
DetailedTraceModules = Literal["model", "worker", "all"]
@config
@dataclass
class ObservabilityConfig:
    """Configuration for observability - metrics and tracing."""

    show_hidden_metrics_for_version: Optional[str] = None
    """Enable deprecated Prometheus metrics that have been hidden since the
    specified version. For example, if a previously deprecated metric has been
    hidden since the v0.7.0 release, you use
    `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
    you migrate to new metrics. The metric is likely to be removed completely
    in an upcoming release."""

    otlp_traces_endpoint: Optional[str] = None
    """Target URL to which OpenTelemetry traces will be sent."""

    collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
    """It makes sense to set this only if `--otlp-traces-endpoint` is set. If
    set, it will collect detailed traces for the specified modules. This
    involves use of possibly costly and or blocking operations and hence might
    have a performance impact.
    Note that collecting detailed timing information for each request can be
    expensive."""

    @cached_property
    def show_hidden_metrics(self) -> bool:
        """Check if the hidden metrics should be shown."""
        hidden_since = self.show_hidden_metrics_for_version
        if hidden_since is None:
            return False
        return version._prev_minor_version_was(hidden_since)

    @cached_property
    def collect_model_forward_time(self) -> bool:
        """Whether to collect model forward time for the request."""
        modules = self.collect_detailed_traces
        if modules is None:
            return False
        return "model" in modules or "all" in modules

    @cached_property
    def collect_model_execute_time(self) -> bool:
        """Whether to collect model execute time for the request."""
        modules = self.collect_detailed_traces
        if modules is None:
            return False
        return "worker" in modules or "all" in modules

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return a hash identifying the settings that affect the structure of
        the computation graph from input ids/embeddings to the final hidden
        states, excluding anything before input ids/embeddings and after the
        final hidden states.
        """
        # No factors to consider: this config will not affect the
        # computation graph.
        factors: list[Any] = []
        digest = hashlib.md5(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()

    def __post_init__(self):
        """Split a comma-joined trace list and validate the OTLP endpoint.

        Raises:
            ValueError: if ``otlp_traces_endpoint`` is set while
                OpenTelemetry is not available.
        """
        modules = self.collect_detailed_traces
        # A single entry containing commas (e.g. "model,worker") is treated
        # as a joined list and split into its parts.
        if modules is not None and len(modules) == 1 and "," in modules[0]:
            self._parse_collect_detailed_traces()

        from vllm.tracing import is_otel_available, otel_import_error_traceback
        otel_missing = not is_otel_available()
        if otel_missing and self.otlp_traces_endpoint is not None:
            raise ValueError(
                "OpenTelemetry is not available. Unable to configure "
                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                f"installed. Original error:\n{otel_import_error_traceback}")

    def _parse_collect_detailed_traces(self):
        """Replace the trace list by its first entry split on commas."""
        assert isinstance(self.collect_detailed_traces, list)
        joined = self.collect_detailed_traces[0]
        self.collect_detailed_traces = cast(list[DetailedTraceModules],
                                            joined.split(","))

View File

@ -285,8 +285,6 @@ class SpeculativeConfig:
max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
max_seq_len_to_capture=self.target_model_config.
max_seq_len_to_capture,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)

View File

@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@config
@dataclass
class SpeechToTextConfig:
    """Configuration for speech-to-text models."""

    sample_rate: float = 16_000
    """Sample rate (Hz) to resample input audio to. Most speech models expect
    16kHz audio input. The input audio will be automatically resampled to this
    rate before processing."""

    max_audio_clip_s: int = 30
    """Maximum duration in seconds for a single audio clip without chunking.
    Audio longer than this will be split into smaller chunks if
    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""

    overlap_chunk_second: int = 1
    """Overlap duration in seconds between consecutive audio chunks when
    splitting long audio. This helps maintain context across chunk boundaries
    and improves transcription quality at split points."""

    min_energy_split_window_size: Optional[int] = 1600
    """Window size in samples for finding low-energy (quiet) regions to split
    audio chunks. The algorithm looks for the quietest moment within this
    window to minimize cutting through speech. Default 1600 samples 100ms
    at 16kHz. If None, no chunking will be done."""

    @property
    def allow_audio_chunking(self) -> bool:
        """Whether chunking is allowed, i.e. a split window size is set."""
        split_window = self.min_energy_split_window_size
        return split_window is not None

View File

@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
direct_register_custom_op(
op_name="all_reduce_symmetric_with_copy",
op_func=all_reduce_symmetric_with_copy_impl,
mutates_args=[],
fake_impl=all_reduce_symmetric_with_copy_fake,
)

View File

@ -392,7 +392,8 @@ class MessageQueue:
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug(
("No available shared memory broadcast block found"
" in %s second."),
" in %s seconds. This typically happens when some"
" processes are hanging."),
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1
@ -455,7 +456,8 @@ class MessageQueue:
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug(
("No available shared memory broadcast block found"
" in %s second."),
" in %s seconds. This typically happens when some"
" processes are hanging."),
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1

View File

@ -574,7 +574,6 @@ class NixlConnectorWorker:
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.use_mla)
self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name)

View File

@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
if supports_custom_op():
from vllm.platforms import current_platform
direct_register_custom_op(
op_name="all_reduce",
op_func=all_reduce,
mutates_args=[],
fake_impl=all_reduce_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="all_gather",
op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -373,7 +373,6 @@ class EngineArgs:
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
enforce_eager: bool = ModelConfig.enforce_eager
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
@ -545,8 +544,6 @@ class EngineArgs:
**model_kwargs["quantization"])
model_group.add_argument("--enforce-eager",
**model_kwargs["enforce_eager"])
model_group.add_argument("--max-seq-len-to-capture",
**model_kwargs["max_seq_len_to_capture"])
model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"])
model_group.add_argument("--logprobs-mode",
@ -1008,7 +1005,6 @@ class EngineArgs:
max_model_len=self.max_model_len,
quantization=self.quantization,
enforce_eager=self.enforce_eager,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
logprobs_mode=self.logprobs_mode,
disable_sliding_window=self.disable_sliding_window,

View File

@ -130,11 +130,6 @@ class LLM:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
hf_token: The token to use as HTTP bearer authorization for remote files
@ -184,7 +179,6 @@ class LLM:
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
@ -281,7 +275,6 @@ class LLM:
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
hf_token=hf_token,
hf_overrides=hf_overrides,

View File

@ -1186,6 +1186,10 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None
if self.use_harmony:
reasoning_content, content, _ = parse_chat_output(token_ids)
if not request.include_reasoning:
reasoning_content = None
if self.tool_parser is not None:
tool_parser = self.tool_parser(tokenizer)
# NOTE: We use token_ids for openai tool parser
@ -1194,10 +1198,7 @@ class OpenAIServingChat(OpenAIServing):
request=request,
token_ids=token_ids, # type: ignore
)
reasoning_content, content = None, tool_call_info.content
if request.include_reasoning:
reasoning_content, content, _ = parse_chat_output(
token_ids)
content = tool_call_info.content
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
@ -1205,10 +1206,6 @@ class OpenAIServingChat(OpenAIServing):
tool_calls=tool_call_info.tool_calls,
)
else:
reasoning_content, content, _ = parse_chat_output(
token_ids)
if not request.include_reasoning:
reasoning_content = None
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,

View File

@ -235,8 +235,6 @@ class OpenAIServingResponses(OpenAIServing):
# Handle the previous response ID.
prev_response_id = request.previous_response_id
if prev_response_id is not None:
if not prev_response_id.startswith("resp_"):
return self._make_invalid_id_error(prev_response_id)
async with self.response_store_lock:
prev_response = self.response_store.get(prev_response_id)
if prev_response is None:
@ -924,9 +922,6 @@ class OpenAIServingResponses(OpenAIServing):
stream: Optional[bool],
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[
StreamingResponsesResponse, None]]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
async with self.response_store_lock:
response = self.response_store.get(response_id)
@ -944,9 +939,6 @@ class OpenAIServingResponses(OpenAIServing):
self,
response_id: str,
) -> Union[ErrorResponse, ResponsesResponse]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
async with self.response_store_lock:
response = self.response_store.get(response_id)
if response is None:
@ -972,13 +964,6 @@ class OpenAIServingResponses(OpenAIServing):
response_id)
return response
def _make_invalid_id_error(self, response_id: str) -> ErrorResponse:
return self.create_error_response(
err_type="invalid_request_error",
message=(f"Invalid 'response_id': '{response_id}'. "
"Expected an ID that begins with 'resp'."),
)
def _make_not_found_error(self, response_id: str) -> ErrorResponse:
return self.create_error_response(
err_type="invalid_request_error",

View File

@ -39,7 +39,7 @@ class DeepSeekV31ToolParser(ToolParser):
self.tool_call_end_token: str = "<tool▁call▁end>"
self.tool_call_regex = re.compile(
r"<tool▁call▁begin>(?P<function_name>.*)<tool▁sep>(?P<function_arguments>.*)<tool▁call▁end>"
r"<tool▁call▁begin>(?P<function_name>.*?)<tool▁sep>(?P<function_arguments>.*?)<tool▁call▁end>"
)
self.stream_tool_call_portion_regex = re.compile(

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import TYPE_CHECKING
@ -12,10 +13,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("openai")
class OpenAIToolParser(ToolParser):
@ -40,17 +44,33 @@ class OpenAIToolParser(ToolParser):
if len(parser.messages) > 0:
for msg in parser.messages:
if len(msg.content) < 1:
continue
msg_text = msg.content[0].text
if msg.recipient and msg.recipient.startswith("functions."):
# If no content-type is given assume JSON, as that's the
# most common case with gpt-oss models.
if not msg.content_type or "json" in msg.content_type:
# load and dump the JSON text to check validity and
# remove any extra newlines or other odd formatting
try:
tool_args = json.dumps(json.loads(msg_text))
except json.JSONDecodeError:
logger.exception(
"Error decoding JSON tool call from response.")
tool_args = msg_text
else:
tool_args = msg_text
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=msg.recipient.split("functions.")[1],
arguments=msg.content[0].text,
arguments=tool_args,
),
))
elif msg.channel == "final":
final_content = msg.content[0].text
final_content = msg_text
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,

View File

@ -119,7 +119,7 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 32
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
@ -187,12 +187,15 @@ if TYPE_CHECKING:
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
VLLM_DBO_COMM_SMS: int = 20
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
VLLM_USE_NCCL_SYMM_MEM: bool = False
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
@ -1014,7 +1017,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# max number splits for cuda graph decode
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"16")),
"32")),
# Number of GPUs per worker in Ray, if it is set to be a fraction,
# it allows ray to schedule multiple actors on a single GPU,
@ -1385,6 +1388,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
# Add optional nvtx scopes for profiling, disable to avoid overheads
"VLLM_NVTX_SCOPES_FOR_PROFILING":
lambda: bool(int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0"))),
# Represent block hashes in KV cache events as 64-bit integers instead of
# raw bytes. Defaults to True for backward compatibility.
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
@ -1413,6 +1420,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
"code_interpreter",
"web_search_preview"]),
# Enable max_autotune & coordinate_descent_tuning in inductor_config
# to compile static shapes passed from compile_sizes in compilation_config
# If set to 1, enable max_autotune; By default, this is enabled (1)
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE":
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1"))),
# If set to 1, enable coordinate_descent_tuning;
# By default, this is enabled (1)
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING":
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
"1"))),
# Flag to enable NCCL symmetric memory allocation and registration
"VLLM_USE_NCCL_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
@ -1513,6 +1531,8 @@ def compute_hash() -> str:
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
]
for key in environment_variables_to_hash:
# if this goes out of sync with environment_variables,

View File

@ -7,7 +7,6 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputProcessingContext
__all__ = [
"DataPrompt",
@ -28,6 +27,4 @@ __all__ = [
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"InputContext",
"InputProcessingContext",
]

View File

@ -1,186 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Union
import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
ModelConfig = Any
AnyTokenizer = Any
_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
logger = init_logger(__name__)
@dataclass(frozen=True)
class InputContext:
    """
    Model-level information that input handling code may consult
    when it needs to modify the inputs.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Return the model's HuggingFace configuration
        (`transformers.PretrainedConfig`) after verifying that it is an
        instance of `typ`.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on
        the model config.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Return the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        config = self.model_config.multimodal_config
        if config is not None:
            return config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Return the model's HuggingFace processor
        (`transformers.ProcessorMixin`), additionally checking its type.

        The lookup is delegated to `cached_processor_from_config`, with
        `typ` passed as `processor_cls` and `kwargs` forwarded unchanged.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        processor = cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            **kwargs,
        )
        return processor

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Instantiate a HuggingFace-like processor class `typ`.

        The `mm_processor_kwargs` from the model's multimodal config act
        as defaults; explicitly passed `kwargs` take precedence.
        """
        mm_config = self.model_config.get_multimodal_config()
        configured = mm_config.mm_processor_kwargs

        # Start from the configured defaults (if any), then let the
        # caller-supplied keyword arguments override them.
        init_kwargs = {} if configured is None else dict(configured)
        init_kwargs.update(kwargs)

        return typ(**init_kwargs)
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """An `InputContext` that additionally carries the tokenizer."""

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as the parent implementation, but always passes this
        context's tokenizer along as the `tokenizer` keyword argument.
        """
        processor = super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )
        return processor

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},
    ) -> Union[BatchFeature, JSONTree]:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        `kwargs` are merged with the multimodal config's processor kwargs
        and filtered by `get_allowed_kwarg_only_overrides` before the
        call. Floating-point tensors in the output are cast to the
        model dtype.

        Raises:
            ValueError: If calling the processor or post-processing its
                output fails; the original exception is chained.
        """
        assert callable(hf_processor)

        mm_config = self.model_config.get_multimodal_config()
        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            mm_config.merge_mm_processor_kwargs(kwargs),
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        def _to_model_dtype(leaf):
            # This mimics the behavior of transformers.BatchFeature:
            # only floating-point tensors are converted.
            if not isinstance(leaf, torch.Tensor):
                return leaf
            if not leaf.is_floating_point():
                return leaf
            return leaf.to(dtype=self.model_config.dtype)

        try:
            output = hf_processor(**data,
                                  **allowed_kwargs,
                                  return_tensors="pt")

            # this emulates output.to(dtype=self.model_config.dtype)
            is_batch_feature = isinstance(output, BatchFeature)
            leaves = output.data if is_batch_feature else output
            cast_output = json_map_leaves(_to_model_dtype, leaves)

            if is_batch_feature:
                return BatchFeature(cast_output)

            logger.warning_once(
                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                "Make sure to match the behaviour of `ProcessorMixin` when "
                "implementing custom processors.")
            return cast_output
        except Exception as exc:
            raise ValueError(
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={allowed_kwargs}") from exc

View File

@ -11,7 +11,6 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
@ -283,7 +282,6 @@ try:
op_func=_lora_expand,
mutates_args=["output_tensor"],
fake_impl=_lora_expand_fake,
dispatch_key=current_platform.dispatch_key,
)
lora_expand = torch.ops.vllm.lora_expand

View File

@ -11,7 +11,6 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
@ -237,7 +236,6 @@ try:
op_func=_lora_shrink,
mutates_args=["output_tensor"],
fake_impl=_lora_shrink_fake,
dispatch_key=current_platform.dispatch_key,
)
lora_shrink = torch.ops.vllm.lora_shrink

View File

@ -40,8 +40,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
ssm_state_indices,
num_accepted_tokens,
scale,
N: tl.constexpr, # num of sequences
T: tl.constexpr, # num of tokens
N: tl.int64, # num of sequences
T: tl.int64, # num of tokens
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,

View File

@ -0,0 +1,123 @@
{
"triton_version": "3.4.0",
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -98,13 +98,16 @@ def select_experts(
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
assert scoring_func == "softmax"
topk_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
topk_logit_vals, topk_idx = torch.topk(router_logits,
k=top_k,
dim=-1,
sorted=False)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids.to(torch.int32)
topk_vals = torch.softmax(topk_logit_vals, dim=-1)
else:
logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
topk_vals = (topk_logit_vals - logZ).exp()
return topk_vals.to(torch.float32), topk_idx.to(torch.int32)
else:
return custom_routing_function(hidden_states=hidden_states,
gating_output=router_logits,

View File

@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8,
mutates_args=[],
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)

View File

@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
direct_register_custom_op(
op_name="fused_marlin_moe",
op_func=fused_marlin_moe,
mutates_args=[],
fake_impl=fused_marlin_moe_fake,
)

View File

@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
direct_register_custom_op(
op_name="outplace_fused_experts",
op_func=outplace_fused_experts,
mutates_args=[],
fake_impl=outplace_fused_experts_fake,
tags=(() if is_torch_equal_or_newer("2.7.0") else
(torch.Tag.needs_fixed_stride_order, )),

View File

@ -23,7 +23,7 @@ if has_triton_kernels():
from triton_kernels.routing import (RoutingData, routing,
routing_from_bitmatrix)
from triton_kernels.tensor import Bitmatrix
except (ModuleNotFoundError, AttributeError) as e:
except (AttributeError, ImportError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible. Error: %s", e)

View File

@ -69,8 +69,6 @@ else:
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk)
elif current_platform.is_cpu():
pass
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
@ -2040,7 +2038,6 @@ direct_register_custom_op(
op_func=moe_forward,
mutates_args=["hidden_states"],
fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
@ -2071,7 +2068,6 @@ direct_register_custom_op(
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)

View File

@ -223,17 +223,13 @@ if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fused_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
@ -241,7 +237,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
@ -249,7 +244,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
@ -257,7 +251,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -103,17 +103,13 @@ if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
ModelWeightParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"UnquantizedLinearMethod",
"CompressedTensorsLinearMethod",
"CompressedTensorsLinearTransformMethod",
"BitBLASLinearMethod",
@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
# The amount of memory allocated for the weights is
# sum(output_partition_sizes) * input_size_per_partition.
try:
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
weight_loader = extra_weight_attrs.pop("weight_loader")
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
except torch.cuda.OutOfMemoryError as e:
logger.error("Failed to create unquantized linear weights: %s", e)
if torch.cuda.is_available():
@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
"Failed to create unquantized linear weights. "
"This may be caused by insufficient memory to allocate "
"the weight.") from e
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

View File

@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
@ -401,5 +400,4 @@ direct_register_custom_op(
op_func=linear_attention,
mutates_args=["output"],
fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@ -464,5 +463,4 @@ direct_register_custom_op(
op_func=mamba_mixer,
mutates_args=["output"],
fake_impl=mamba_mixer_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@ -765,5 +764,4 @@ direct_register_custom_op(
op_func=mamba_mixer2,
mutates_args=["output"],
fake_impl=mamba_mixer2_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -17,7 +17,6 @@ from .mamba_ssm import softplus
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_H': 1}),
triton.Config({'BLOCK_SIZE_H': 2}),
triton.Config({'BLOCK_SIZE_H': 4}),
triton.Config({'BLOCK_SIZE_H': 8}),

View File

@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata)
@ -251,5 +250,4 @@ direct_register_custom_op(
op_func=short_conv,
mutates_args=["output"],
fake_impl=short_conv_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
@ -63,7 +64,7 @@ __all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4MoeMethod"
"CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod"
]
@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
layer.moe_config)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant,
input_quant):
return CompressedTensorsW4A8Int8MoEMethod(quant_config,
layer.moe_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""
CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform
- Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles)
- Scales: Fp32 for Channelwise , bf16 for groupwise quantization
- Bias: Same data type as original weights
- Activations: FP32/Bf16 dynamic per-token (A8 Int),
quantized inside the kernel
"""
def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        moe: FusedMoEConfig):
    """Validate the W4A8-int scheme and the platform it will run on.

    Args:
        quant_config: Compressed-tensors config; the weight and
            input-activation quantization args are read from its
            `target_scheme_map["Linear"]` entry.
        moe: Fused-MoE config forwarded to the base class.

    Raises:
        ValueError: If the activation args are missing or not dynamic
            per-token, the weight args are missing or not 4-bit, or the
            current platform is not CPU.
        RuntimeError: On Arm CPUs, if this PyTorch build does not expose
            the `_dyn_quant_*` 4-bit ops.
    """
    super().__init__(moe)
    self.has_bias = self.moe.has_bias
    self.quant_config = quant_config

    # Validate scheme: weights=W4 (channel or group),
    # activations=dynamic TOKEN (A8)
    linear_scheme = self.quant_config.target_scheme_map["Linear"]
    wq = linear_scheme.get("weights")
    aq = linear_scheme.get("input_activations")

    # Must be dynamic per-token activations. `.get` returns None when
    # the entry is absent, so reject that explicitly instead of failing
    # with an AttributeError on `aq.strategy`.
    if (aq is None or aq.strategy != QuantizationStrategy.TOKEN
            or not aq.dynamic):
        raise ValueError(
            "W4A8-int MoE needs dynamic per-token activation quantization."
        )

    # Same None guard for the weight args before reading their fields.
    if wq is None or wq.num_bits != 4:
        raise ValueError(
            "This method only supports 4-bit weights (num_bits=4).")

    # Weight can be channel-wise (group_size=None) or group-wise;
    # -1 is the internal marker for channel-wise.
    self.group_size = wq.group_size if (wq.group_size is not None) else -1

    # CPU only
    if not current_platform.is_cpu():
        raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.")

    # Arm: check _dyn ops availability
    if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
        try:
            _ = torch.ops.aten._dyn_quant_matmul_4bit
            _ = torch.ops.aten._dyn_quant_pack_4bit_weight
        except AttributeError as err:
            # Plain (not triple-quoted) string so the message does not
            # carry a newline and source indentation.
            raise RuntimeError(
                f"PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; "
                "install a newer build.") from err

    self.static_input_scales = False  # always dynamic per token
# ---- parameter creation ----
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register unpacked weights, scales, optional biases and packed
    placeholders on ``layer``.

    Shapes are per local rank (TP/EP):
        w13_weight: [E, 2 * I_local, H] int8 (int4 values in [-8, 7])
        w2_weight : [E, H, I_local]     int8
    Scales carry a single column per output row when quantization is
    channel-wise (``group_size == -1``) and ``in_features // group_size``
    columns per output row when it is group-wise.
    """
    group_size = self.group_size
    is_channelwise = group_size == -1
    gate_up_rows = 2 * intermediate_size_per_partition

    def _frozen(tensor: torch.Tensor) -> torch.nn.Parameter:
        # Every parameter registered here is inference-only.
        return torch.nn.Parameter(tensor, requires_grad=False)

    def _scale_columns(in_features: int) -> int:
        # Number of scale entries stored for each output row.
        if is_channelwise:
            return 1
        return in_features // group_size

    # Unpacked int4-as-int8 weights that the loader will fill.
    gate_up_weight = _frozen(
        torch.empty(num_experts,
                    gate_up_rows,
                    hidden_size,
                    dtype=torch.int8))
    set_weight_attrs(gate_up_weight, extra_weight_attrs)
    layer.register_parameter("w13_weight", gate_up_weight)

    down_weight = _frozen(
        torch.empty(num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                    dtype=torch.int8))
    set_weight_attrs(down_weight, extra_weight_attrs)
    layer.register_parameter("w2_weight", down_weight)

    # Scales: float32 is used for channel-wise quantization and bfloat16
    # for group-wise quantization (the dtypes the KleidiAI kernels accept).
    scale_dtype = torch.float32 if is_channelwise else torch.bfloat16
    scale_strategy = "channel" if is_channelwise else "group"
    scale_specs = (
        ("w13_weight_scale", gate_up_rows, hidden_size),
        ("w2_weight_scale", hidden_size, intermediate_size_per_partition),
    )
    for scale_name, out_rows, in_features in scale_specs:
        scale = _frozen(
            torch.ones(num_experts,
                       out_rows,
                       _scale_columns(in_features),
                       dtype=scale_dtype))
        set_weight_attrs(scale, {
            "quant_method": scale_strategy,
            **extra_weight_attrs
        })
        layer.register_parameter(scale_name, scale)

    if self.has_bias:
        bias_specs = (("w13_bias", gate_up_rows), ("w2_bias", hidden_size))
        for bias_name, out_rows in bias_specs:
            bias = _frozen(
                torch.zeros(num_experts, out_rows, dtype=params_dtype))
            layer.register_parameter(bias_name, bias)
            set_weight_attrs(bias, extra_weight_attrs)

    # Empty placeholders; replaced with the packed payload after loading.
    for packed_name in ("w13_weight_packed", "w2_weight_packed"):
        layer.register_parameter(packed_name, _frozen(torch.empty(0)))
        set_weight_attrs(getattr(layer, packed_name), extra_weight_attrs)

    # Dimensions consumed by the 4-bit fused matmuls.
    layer.w13_in_features = hidden_size
    layer.w13_out_features = gate_up_rows
    layer.w2_in_features = intermediate_size_per_partition
    layer.w2_out_features = hidden_size
    layer.group_size = group_size
# post-load packing to dyn-4bit KleidiAI kernel's format
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Pack each expert's weights, scales and bias into the dyn-4bit
    kernel format, then release the unpacked tensors.

    The packed per-expert payloads are stacked along dim 0 and stored in
    ``w13_weight_packed`` / ``w2_weight_packed``; the raw weight, scale and
    (when present) bias parameters are replaced with empty tensors.
    """
    num_experts = layer.w13_weight.shape[0]
    hidden = layer.w13_in_features
    gate_up = layer.w13_out_features
    inter = layer.w2_in_features
    group_size = layer.group_size
    is_channelwise = group_size == -1

    def _pack_matrix(int4_as_int8_2d: torch.Tensor,
                     scales_2d: torch.Tensor,
                     bias_1d: Optional[torch.Tensor], in_features: int,
                     out_features: int) -> torch.Tensor:
        # int4 values are stored as int8 in [-8, 7]: shift them to an
        # unsigned nibble, then pack pairs along the input dimension
        # (odd column -> high nibble, even column -> low nibble).
        shifted = int4_as_int8_2d.add(8)  # [out, in]
        high = shifted[:, 1::2]
        low = shifted[:, ::2]
        packed_nibbles = ((high << 4) | low).to(
            torch.uint8)  # [out, in//2]
        # float32 scales for channel-wise, bfloat16 for group-wise.
        if is_channelwise:
            scales = scales_2d.to(torch.float32)
        else:
            scales = scales_2d.to(torch.bfloat16)
        bias = bias_1d.to(torch.float32) if bias_1d is not None else None
        # Channel-wise quantization is passed as one block spanning the
        # whole input dimension.
        block_size = in_features if is_channelwise else group_size
        return torch.ops.aten._dyn_quant_pack_4bit_weight(
            packed_nibbles, scales, bias, block_size, in_features,
            out_features)

    has_w13_bias = getattr(layer, "w13_bias", None) is not None
    has_w2_bias = getattr(layer, "w2_bias", None) is not None

    # Pack per expert.
    w13_payloads = [
        _pack_matrix(
            layer.w13_weight[expert],  # [2I, H]
            layer.w13_weight_scale[expert],  # [2I, H/g or 1]
            layer.w13_bias[expert] if has_w13_bias else None,  # [2I]
            hidden,
            gate_up) for expert in range(num_experts)
    ]
    w2_payloads = [
        _pack_matrix(
            # w2 is already laid out as [out, in] == [H, IN].
            layer.w2_weight[expert],  # [H, IN]
            layer.w2_weight_scale[expert],  # [H, IN/g or 1]
            layer.w2_bias[expert] if has_w2_bias else None,  # [H]
            inter,
            layer.w2_out_features) for expert in range(num_experts)
    ]

    # Every expert's payload has the same shape, so stack along dim 0.
    replace_parameter(
        layer, "w13_weight_packed",
        torch.nn.Parameter(torch.stack(w13_payloads, dim=0),
                           requires_grad=False))
    replace_parameter(
        layer, "w2_weight_packed",
        torch.nn.Parameter(torch.stack(w2_payloads, dim=0),
                           requires_grad=False))

    # Free raw tensors/scales/bias now that they live in the payload.
    stale_names = [
        "w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale"
    ]
    if has_w13_bias:
        stale_names.append("w13_bias")
    if has_w2_bias:
        stale_names.append("w2_bias")
    for stale_name in stale_names:
        replace_parameter(
            layer, stale_name,
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
# Always returns None; `layer` is not inspected.
# CPU dynamic 4-bit MoE path does not use modular kernels or
# fused_experts; quant config is not needed.
return None
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Route tokens to experts and run the CPU dynamic 4-bit MoE op.

    Expert selection is delegated to ``select_experts``; the selected ids
    and weights are passed, together with the packed weights stored on
    ``layer``, to ``torch.ops._C.dynamic_4bit_int_moe``.

    EPLB and ``expert_map`` are rejected by assertion, and only the
    "silu", "swigluoai" and "swiglu" activations are accepted.
    """
    assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
    supported_activations = ("silu", "swigluoai", "swiglu")
    assert activation in supported_activations, (
        "Only SiLU/SwiGLUGU/SwiGLUUG are supported.")
    assert expert_map is None, (
        "expert_map/EP not implemented\nfor CPU dyn-4bit MoE.")

    # Integer codes passed to the kernel:
    # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
    activation_codes = (("swiglu", 0), ("swigluoai", 1), ("silu", 2))

    def _act_kind(s: str) -> int:
        for known_name, code in activation_codes:
            if s == known_name:
                return code
        raise ValueError(f"Unknown activation '{s}'")

    # Pick the top-k experts and their routing weights for every token.
    routing_weights, expert_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
    )

    return torch.ops._C.dynamic_4bit_int_moe(
        x,
        expert_ids.to(torch.long),
        routing_weights,
        layer.w13_weight_packed,
        layer.w2_weight_packed,
        layer.w2_out_features,
        layer.w2_in_features,
        layer.w13_out_features,
        layer.group_size,
        apply_router_weight_on_input,
        _act_kind(activation),
    )

View File

@ -4,7 +4,6 @@ import logging
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_gemm_nt
@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
direct_register_custom_op(
op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_deepgemm_block_scaled_mm,
mutates_args=[],
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -161,7 +161,6 @@ try:
direct_register_custom_op(
op_name="_fused_mul_mat_gguf",
op_func=_fused_mul_mat_gguf,
mutates_args=[],
fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
@ -273,7 +272,6 @@ try:
direct_register_custom_op(
op_name="_fused_moe_gguf",
op_func=_fused_moe_gguf,
mutates_args=[],
fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
@ -319,7 +317,6 @@ try:
direct_register_custom_op(
op_name="_apply_gguf_embedding",
op_func=_apply_gguf_embedding,
mutates_args=[],
fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding

View File

@ -51,9 +51,7 @@ if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
elif current_platform.is_rocm() or (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64)

View File

@ -91,9 +91,7 @@ if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()):
@ -132,13 +130,14 @@ def _w8a8_triton_block_scaled_mm_fake(
device=qx.device)
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
mutates_args=[],
fake_impl=_w8a8_triton_block_scaled_mm_fake,
dispatch_key="CUDA",
)
# Note: the check can be removed when CPU torch > 2.7
if not current_platform.is_cpu():
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
dispatch_key="CUDA",
)
# TODO fix ROCm->Triton custom path:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from typing import Any, Callable, Optional
import torch
@ -21,6 +21,10 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
value_layout_opts: dict[str, Any] = {}
scale_layout_opts: dict[str, Any] = {}
if (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and not is_torch_equal_or_newer("2.8.1")):
@ -28,8 +32,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"Mxfp4 on hopper is running on torch < 2.8.1, "
"this cause swizling to be disabled, which may "
"cause performance degradation. Please upgrade to torch nightly")
value_layout, value_layout_opts = StridedLayout, dict()
scale_layout, scale_layout_opts = StridedLayout, dict()
value_layout = StridedLayout
scale_layout = StridedLayout
elif current_platform.is_rocm():
from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout,
StridedLayout)
from vllm.platforms.rocm import on_gfx950
value_layout = StridedLayout
scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
else:
value_layout, value_layout_opts = \
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
@ -113,7 +124,6 @@ try:
direct_register_custom_op(
op_name="dequant_mxfp4",
op_func=_dequant_mxfp4,
mutates_args=[],
fake_impl=_dequant_mxfp4_fake,
)
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
@ -124,7 +134,6 @@ try:
direct_register_custom_op(
op_name="quant_dequant_mxfp4",
op_func=_quant_dequant_mxfp4,
mutates_args=[],
fake_impl=_quant_dequant_mxfp4_fake,
)
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4

View File

@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
direct_register_custom_op(
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
mutates_args=[],
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -147,5 +147,4 @@ direct_register_custom_op(
op_func=_flashinfer_rotary_embedding,
mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_flashinfer_rotary_embedding_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
direct_register_custom_op(
op_name="rocm_unquantized_gemm_impl",
op_func=rocm_unquantized_gemm_impl,
mutates_args=[],
fake_impl=rocm_unquantized_gemm_impl_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, Optional, Union, cast
from typing import Annotated, Literal, Optional, Union
import torch
from torch import nn
@ -347,12 +347,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
**kwargs) -> torch.Tensor:
target_dtype = vision_tower.get_input_embeddings().weight.dtype
image_features = vision_tower(pixel_values.to(dtype=target_dtype),
**kwargs)
def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
**kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
target_dtype: torch.dtype = \
vision_tower.get_input_embeddings().weight.dtype
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
def select_features(leaf: torch.Tensor):
return self._select_image_features(
@ -360,10 +364,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
strategy=self.config.vision_feature_select_strategy,
)
return cast(
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
return json_map_leaves(select_features, image_features)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:

View File

@ -245,19 +245,6 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
}
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config
config.max_seq_len_to_capture = config.max_model_len
logger.info(
"Setting max_seq_len_to_capture to %d "
"to ensure that CUDA graph capture "
"covers sequences of length up to max_model_len.",
config.max_model_len)
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
@ -426,7 +413,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"XLMRobertaModel": JinaRobertaModelConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
"GptOssForCausalLM": GptOssForCausalLMConfig,
"MambaForCausalLM": MambaModelConfig,
"Mamba2ForCausalLM": MambaModelConfig,

View File

@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
direct_register_custom_op(
op_name="sequence_parallel_chunk",
op_func=sequence_parallel_chunk,
mutates_args=[],
fake_impl=sequence_parallel_chunk_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)

View File

@ -26,6 +26,7 @@ from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
GeluAndMul,
@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import SupportsQuant
from .utils import (AutoWeightsLoader, extract_layer_index,
@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
logger = init_logger(__name__)
EPS = torch.tensor(torch.finfo().min)
class Gemma3nAltUp(nn.Module):
"""Alternating updates (Altup)
@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
return corrected_predictions
@support_torch_compile
class Gemma3nTextModel(nn.Module, SupportsQuant):
# This enables torch.compile if --kv-sharing-fast-prefill passed
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
class Gemma3nSelfDecoder(nn.Module):
"""
Includes altup embedding and self decoder layers
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma3nDecoderLayer],
layer_idx_start: int,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
quant_config = vllm_config.quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
prefix=f"{prefix}.altup_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs)
])
self.altup_unembed_projections = nn.ModuleList([
ColumnParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
gather_output=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs)
])
# Transformer blocks.
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma3nDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
)
self.eps = torch.tensor(torch.finfo().min)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
def get_per_layer_input_embeddings(
self, input_ids: torch.Tensor) -> torch.Tensor:
@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
return self.embed_tokens_per_layer(
per_layer_inputs_tokens) * self.embed_scale_per_layer
def forward(
def get_per_layer_inputs(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
per_layer_inputs: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None:
hidden_states_0 = inputs_embeds
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
hidden_states_0: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor],
) -> torch.Tensor:
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1],
@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
)
per_layer_projection = self.per_layer_projection_norm(
per_layer_projection)
if per_layer_inputs is not None:
# Profiling run does not compute per_layer_inputs
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale
else:
per_layer_inputs = per_layer_projection
return per_layer_inputs
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
# Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
dim=-1,
keepdim=True)**0.5
hidden_states[i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
hidden_states = torch.stack(hidden_states, dim=0)
new_magnitude, EPS)
hidden_states = torch.stack(hidden_states, dim=-1)
return hidden_states
# Transformer blocks.
for layer_idx, layer in enumerate(self.layers):
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    per_layer_inputs: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Embed the inputs and run them through this module's decoder layers.

    Returns:
        A pair ``(hidden_states, adjusted_per_layer_inputs)``: the layer
        output permuted back to the layout produced by ``altup_embed``,
        and the per-layer inputs returned by ``get_per_layer_inputs``.
    """
    # Pre-computed embeddings take precedence over token ids.
    token_embeds = (self.get_input_embeddings(input_ids)
                    if inputs_embeds is None else inputs_embeds)
    adjusted_per_layer_inputs = self.get_per_layer_inputs(
        token_embeds, per_layer_inputs)

    # Move the trailing altup axis to the front for the layers:
    # [altup_num_inputs, num_tokens, hidden_size]
    stacked = self.altup_embed(token_embeds).permute(2, 0, 1)

    # Layer indices are global, hence the offset by layer_idx_start.
    for layer_idx, decoder_layer in enumerate(self.decoder_layers,
                                              start=self.layer_idx_start):
        stacked = decoder_layer(
            positions=positions,
            hidden_states=stacked,
            per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
            **kwargs,
        )

    # Back to [num_tokens, hidden_size, altup_num_inputs]
    return stacked.permute(1, 2, 0), adjusted_per_layer_inputs
# This enables torch.compile if --kv-sharing-fast-prefill passed
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
class Gemma3nCrossDecoder(nn.Module):
"""
Cross-decoder layers
"""
# Store the decoder layers this module runs and the global index of the
# first one. `vllm_config` and `prefix` are accepted but not read here.
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma3nDecoderLayer],
layer_idx_start: int,
):
super().__init__()
self.decoder_layers = decoder_layers
# Offset added to the local loop index in forward() to select the
# matching slice of per_layer_inputs.
self.layer_idx_start = layer_idx_start
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor,
**kwargs,
) -> torch.Tensor:
# [altnum_inputs, num_tokens, hidden_size]
hidden_states = hidden_states.permute(2, 0, 1)
for idx, layer in enumerate(self.decoder_layers):
layer_idx = idx + self.layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = layer(
positions=positions,
@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
per_layer_input=per_layer_inputs[:, layer_idx, :],
**kwargs,
)
# [num_tokens, hidden_size, altnum_inputs]
hidden_states = hidden_states.permute(1, 2, 0)
return hidden_states
# This disables torch.compile if --kv-sharing-fast-prefill passed
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
cache_config.kv_sharing_fast_prefill)
class Gemma3nTextModel(nn.Module, SupportsQuant):
# Build the text model: unembed projections, all decoder layers, the
# self-/cross-decoder wrappers that share those layers, the final norm and
# (only when kv_sharing_fast_prefill is set) static input buffers.
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
# One projection per additional altup input (indices 1..altup_num_inputs-1).
self.altup_unembed_projections = nn.ModuleList([
ColumnParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
gather_output=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs)
])
# Allocate config.num_kv_shared_layers layers for self-decoder
# NOTE(review): make_layers is asked for all config.num_hidden_layers
# layers; they are split between the two decoders below.
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma3nDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
# Layers before this index go to the self-decoder, the rest to the
# cross-decoder.
first_kv_shared_layer_idx = (config.num_hidden_layers -
config.num_kv_shared_layers)
# NOTE(sarckk): importing this top level seems to cause issues
# during running of tests.
from vllm.compilation.backends import set_model_tag
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
# NOTE(review): the 0-19 / 20-30 ranges are for one specific config; the
# actual split is first_kv_shared_layer_idx.
with set_model_tag("self_decoder"):
self.self_decoder = Gemma3nSelfDecoder(
vllm_config=vllm_config,
prefix=f"{prefix}.self_decoder",
decoder_layers=self.layers[:first_kv_shared_layer_idx],
layer_idx_start=0,
)
# Layer idx 20-30 are cross-decoder layers in YOCO
with set_model_tag("cross_decoder"):
self.cross_decoder = Gemma3nCrossDecoder(
vllm_config=vllm_config,
prefix=f"{prefix}.cross_decoder",
decoder_layers=self.layers[first_kv_shared_layer_idx:],
layer_idx_start=first_kv_shared_layer_idx,
)
self.norm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
)
# Selects fast_prefill_forward over normal_forward in forward().
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
if self.fast_prefill_enabled:
# Allocate static buffers for CUDAGraph
# TODO(sarckk): Extract this functionality to interface
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
# Place the buffers on the same device as the module's parameters.
device = next(self.parameters()).device
self.positions = torch.zeros(max_num_tokens,
dtype=torch.int64,
device=device)
# [max_num_tokens, hidden_size, altup_num_inputs]
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size,
self.config.altup_num_inputs),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
# [max_num_tokens, num_hidden_layers, hidden_size_per_layer_input]
self.per_layer_inputs = torch.zeros(
(max_num_tokens, self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
# Read-only alias for the token embedding owned by the self-decoder.
@property
def embed_tokens(self):
return self.self_decoder.embed_tokens
# Delegates to the self-decoder, which owns the per-layer embedding table.
def get_per_layer_input_embeddings(
self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
# Delegates to the self-decoder's get_input_embeddings.
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_input_embeddings(input_ids)
# Forward pass used when kv_sharing_fast_prefill is enabled.
# Runs the self-decoder on every token, then runs the cross-decoder only on
# the rows selected by logits_indices_padded (all rows when no
# KVSharingFastPrefillMetadata is available), and merges the results.
# Inputs to both decoders are staged through the static buffers
# self.positions / self.hidden_states / self.per_layer_inputs.
def fast_prefill_forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
logits_indices_padded, num_logits_indices = None, None
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (self.fast_prefill_enabled and attn_metadata is not None):
assert isinstance(attn_metadata, dict)
# Last layer is a KV sharing layer
layer_attn_metadata = attn_metadata[
self.layers[-1].self_attn.attn.layer_name]
if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
logits_indices_padded = (
layer_attn_metadata.logits_indices_padded)
num_logits_indices = layer_attn_metadata.num_logits_indices
# Copy inputs for cudagraph
batch_size = positions.size(0)
self.positions[:batch_size].copy_(positions)
self_decoder_hidden_states, per_layer_inputs_adjusted = \
self.self_decoder(
input_ids=input_ids,
positions=self.positions[:batch_size],
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
# Without fast-prefill metadata, fall back to selecting every row so the
# cross-decoder processes all tokens.
if logits_indices_padded is None:
logits_indices_padded = torch.arange(
positions.size(0),
dtype=positions.dtype,
device=positions.device,
)
# NOTE(sarckk): There is currently a bug caused by
# vLLM converting output of last piecewise CUDA graph
# to weakref, causing memory to be prematurely freed
# when there are multiple compilation units
# Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states = self_decoder_hidden_states.clone()
# Copy inputs for cudagraph
# self.positions is overwritten here; the self-decoder call that read it
# has already returned.
num_padded_logits_indices = logits_indices_padded.size(0)
self.positions[:num_padded_logits_indices].copy_(
positions[logits_indices_padded])
self.hidden_states[:num_padded_logits_indices].copy_(
self_decoder_hidden_states[logits_indices_padded])
self.per_layer_inputs[:num_padded_logits_indices].copy_(
per_layer_inputs_adjusted[logits_indices_padded])
cross_decoder_hidden_states = self.cross_decoder(
positions=self.positions[:num_padded_logits_indices],
hidden_states=self.hidden_states[:num_padded_logits_indices],
per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
**kwargs,
)
if num_logits_indices is not None:
assert num_logits_indices > 0
# Merge cross-decoder and self-decoder hidden states
# Only the first num_logits_indices rows are written back (the remaining
# entries of logits_indices_padded are presumably padding); other rows
# keep the self-decoder output.
hidden_states[logits_indices_padded[:num_logits_indices]] = (
cross_decoder_hidden_states[:num_logits_indices])
else:
hidden_states = cross_decoder_hidden_states
return hidden_states
def normal_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    per_layer_inputs: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Run the self-decoder and then the cross-decoder on every token.

    The per-layer inputs returned by the self-decoder (not the ones passed
    in) are what the cross-decoder receives.
    """
    self_decoder_out, adjusted_per_layer_inputs = self.self_decoder(
        input_ids=input_ids,
        positions=positions,
        inputs_embeds=inputs_embeds,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )
    return self.cross_decoder(
        positions=positions,
        hidden_states=self_decoder_out,
        per_layer_inputs=adjusted_per_layer_inputs,
        **kwargs,
    )
def altup_unembed(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# Altup unembed.
target_magnitude = torch.mean(hidden_states[0]**2,
target_magnitude = torch.mean(hidden_states[..., 0]**2,
dim=-1,
keepdim=True)**0.5
for i in range(1, self.config.altup_num_inputs):
hidden_states[i] = self.altup_unembed_projections[i - 1](
hidden_states[i])
new_magnitude = torch.mean(hidden_states[i]**2,
hidden_states[..., i] = self.altup_unembed_projections[i - 1](
hidden_states[..., i])
new_magnitude = torch.mean(hidden_states[..., i]**2,
dim=-1,
keepdim=True)**0.5
hidden_states[i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
hidden_states = torch.mean(hidden_states, dim=0)
hidden_states[..., i] *= target_magnitude / torch.maximum(
new_magnitude, EPS)
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
hidden_states = torch.mean(hidden_states, dim=-1)
return hidden_states
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    per_layer_inputs: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the decoder stack, then the altup unembed and the final norm.

    ``fast_prefill_forward`` is used when ``self.fast_prefill_enabled`` is
    set and ``normal_forward`` otherwise. ``intermediate_tensors`` is
    accepted but not read.
    """
    run_decoders = (self.fast_prefill_forward
                    if self.fast_prefill_enabled else self.normal_forward)
    decoded = run_decoders(
        input_ids,
        positions,
        inputs_embeds,
        per_layer_inputs,
        **kwargs,
    )
    return self.norm(self.altup_unembed(decoded))
def load_weights(self, weights: Iterable[tuple[str,
@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# decoder layer weights, altup_unembed_projections and rmsnorm
# are initialized in text model, others are in self decoder
if (not name.startswith('layers')
and not name.startswith('altup_unembed_projections')
and not name.startswith('norm')):
name = f"self_decoder.{name}"
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for compressed-tensors quantization

View File

@ -308,13 +308,11 @@ class GraniteModel(nn.Module):
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
hidden_states *= self.config.embedding_multiplier
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
@ -322,7 +320,6 @@ class GraniteModel(nn.Module):
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
@ -475,10 +472,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[tuple[str,

View File

@ -298,17 +298,14 @@ class GraniteMoeModel(nn.Module):
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.embedding_multiplier
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states
@ -523,10 +520,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[tuple[str,

View File

@ -195,17 +195,14 @@ class GraniteMoeSharedModel(nn.Module):
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.embedding_multiplier
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states
@ -323,10 +320,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[tuple[str,

View File

@ -29,7 +29,6 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
@ -37,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

View File

@ -4,7 +4,7 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union, cast)
Union)
import torch
import torch.nn as nn
@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@ -28,8 +27,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
@ -622,7 +623,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values)
def select_features(leaf: torch.Tensor):
return self._select_image_features(
@ -630,10 +632,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
strategy=self.config.vision_feature_select_strategy,
)
return cast(
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
return json_map_leaves(select_features, image_features)
def _process_image_pixels(
self,

Some files were not shown because too many files have changed in this diff Show More