diff --git a/.buildkite/scripts/run-prime-rl-test.sh b/.buildkite/scripts/run-prime-rl-test.sh new file mode 100755 index 0000000000000..5b25c358fc4aa --- /dev/null +++ b/.buildkite/scripts/run-prime-rl-test.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Setup script for Prime-RL integration tests +# This script prepares the environment for running Prime-RL tests with nightly vLLM + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +PRIME_RL_REPO="https://github.com/PrimeIntellect-ai/prime-rl.git" +PRIME_RL_DIR="${REPO_ROOT}/prime-rl" + +echo "Setting up Prime-RL integration test environment..." + +# Clean up any existing Prime-RL directory +if [ -d "${PRIME_RL_DIR}" ]; then + echo "Removing existing Prime-RL directory..." + rm -rf "${PRIME_RL_DIR}" +fi + +# Install UV if not available +if ! command -v uv &> /dev/null; then + echo "Installing UV package manager..." + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env +fi + +# Clone Prime-RL repository at specific branch for reproducible tests +PRIME_RL_BRANCH="integ-vllm-main" +echo "Cloning Prime-RL repository at branch: ${PRIME_RL_BRANCH}..." +git clone --branch "${PRIME_RL_BRANCH}" --single-branch "${PRIME_RL_REPO}" "${PRIME_RL_DIR}" +cd "${PRIME_RL_DIR}" + +echo "Setting up UV project environment..." +export UV_PROJECT_ENVIRONMENT=/usr/local +ln -s /usr/bin/python3 /usr/local/bin/python + +# Remove vllm pin from pyproject.toml +echo "Removing vllm pin from pyproject.toml..." +sed -i '/vllm==/d' pyproject.toml + +# Sync Prime-RL dependencies +echo "Installing Prime-RL dependencies..." +uv sync --inexact && uv sync --inexact --all-extras + +# Verify installation +echo "Verifying installations..." 
+uv run python -c "import vllm; print(f'vLLM version: {vllm.__version__}')" +uv run python -c "import prime_rl; print('Prime-RL imported successfully')" + +echo "Prime-RL integration test environment setup complete!" + +echo "Running Prime-RL integration tests..." +export WANDB_MODE=offline # this makes this test not require a WANDB_API_KEY +uv run pytest -vs tests/integration/test_rl.py -m gpu + +echo "Prime-RL integration tests completed!" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aef6d709722fa..200ed344c4e86 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -887,6 +887,8 @@ steps: - tests/v1/test_external_lb_dp.py - tests/v1/entrypoints/openai/test_multi_api_servers.py - vllm/v1/engine/ + - vllm/v1/worker/ + - tests/v1/worker/test_worker_memory_snapshot.py commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py @@ -908,6 +910,7 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s models/multimodal/generation/test_maverick.py + - pytest -v -s v1/worker/test_worker_memory_snapshot.py - label: Plugin Tests (2 GPUs) # 40min timeout_in_minutes: 60 @@ -1042,3 +1045,15 @@ steps: commands: - pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + +##### RL Integration Tests ##### +- label: Prime-RL Integration Test # 15min + timeout_in_minutes: 30 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + source_file_dependencies: + - vllm/ + - .buildkite/scripts/run-prime-rl-test.sh + commands: + - bash .buildkite/scripts/run-prime-rl-test.sh diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 73b4aa5a87e07..a0350625491f4 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ 
b/benchmarks/benchmark_serving_structured_output.py @@ -449,7 +449,8 @@ async def benchmark( def prepare_extra_body(request) -> dict: extra_body = {} # Add the schema to the extra_body - extra_body[request.structure_type] = request.schema + extra_body["structured_outputs"] = {} + extra_body["structured_outputs"][request.structure_type] = request.schema return extra_body print("Starting initial single prompt test run...") diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py new file mode 100644 index 0000000000000..b419b2fa0e3eb --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe +kernel. Both kernels take in fp8 quantized weights and 16-bit activations, +but use different quantization strategies and backends. +""" + +import nvtx +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser + +# Weight shapes for different models: [num_experts, topk, hidden_size, +# intermediate_size] +WEIGHT_SHAPES_MOE = { + "mixtral-8x7b": [ + [8, 2, 4096, 14336], + ], + "deepseek-v2": [ + [160, 6, 5120, 12288], + ], + "custom-small": [ + [8, 2, 2048, 7168], + ], + "glm45-fp8": [ + [128, 8, 4096, 1408], + ], + "Llama-4-Maverick-17B-128E-Instruct-FP8": [ + [128, 1, 5120, 8192], + ], +} + +DEFAULT_MODELS = [ + "mixtral-8x7b", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False, True] +PER_OUT_CH_OPTS = 
[False, True] + +FP8_DTYPE = current_platform.fp8_dtype() + + +def bench_run( + results: list, + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + + # Create input activations + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + + # Create weights + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + # Create FP8 quantized weights and scales for both kernels + w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE) + w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE) + + # Create scales based on quantization strategy + if per_out_ch: + # Per-channel quantization + w1_scale = torch.empty( + (num_experts, 2 * n, 1), device=device, dtype=torch.float32 + ) + w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32) + else: + # Per-tensor quantization + w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + + # Quantize weights + for expert in range(num_experts): + if per_out_ch: + # Per-channel quantization - not yet implemented properly + # For now, fall back to per-tensor quantization + w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert]) + # Expand scalar scales to the expected per-channel shape + w1_scale[expert] = w1_scale_temp.expand(2 * n, 1) + w2_scale[expert] = w2_scale_temp.expand(k, 1) + else: + # Per-tensor quantization + w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert]) + # Store scalar scales in [1, 1] tensors + w1_scale[expert, 0, 0] = w1_scale_temp + w2_scale[expert, 0, 0] = 
w2_scale_temp + + # Prepare weights for CUTLASS (no transpose needed) + w1_fp8q_cutlass = w1_fp8q # Keep original [E, 2N, K] + w2_fp8q_cutlass = w2_fp8q # Keep original [E, K, N] + + # Create router scores and get topk + score = torch.randn((m, num_experts), device=device, dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + # WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization + # Force per-tensor quantization for all cases to match working e2e setup + a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32) + a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32) + + # Force per-tensor quantization for all cases + per_act_token = False + + # Create stride tensors for CUTLASS + ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device) + ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device) + c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device) + c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device) + + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + a2_scale: torch.Tensor, + num_repeats: int, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + ) + + for _ in range(num_repeats): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + + def run_cutlass_moe_fp8( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + 
a1_scale: torch.Tensor, + a2_scale: torch.Tensor, + num_repeats: int, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + ) + + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp8", color="blue"): + cutlass_moe_fp8( + a=a, + w1_q=w1, + w2_q=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, + quant_config=quant_config, + activation="silu", + global_num_experts=num_experts, + ) + + # Pre-create quantization config to avoid creating it inside CUDA graph + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + ) + + # Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly) + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + # Capture 10 invocations like benchmark_moe.py + for _ in range(10): + cutlass_moe_fp8( + a=a, + w1_q=w1_fp8q_cutlass, + w2_q=w2_fp8q_cutlass, + topk_weights=topk_weights, + topk_ids=topk_ids, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, + quant_config=quant_config, + activation="silu", + global_num_experts=num_experts, + ) + torch.cuda.synchronize() + + # Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly) + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + # Capture 10 invocations like benchmark_moe.py + for _ in range(10): + fused_experts( + a, + w1_fp8q, + w2_fp8q, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + torch.cuda.synchronize() + + def 
bench_cuda_graph(graph, num_warmup=5, num_iters=100): + """Benchmark CUDA graph using events like benchmark_moe.py""" + # Warmup + for _ in range(num_warmup): + graph.replay() + torch.cuda.synchronize() + + # Timing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies = [] + for _ in range(num_iters): + torch.cuda.synchronize() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + + # Divide by 10 since graph contains 10 calls + return sum(latencies) / (num_iters * 10) + + # Benchmark parameters + num_warmup = 5 + num_iters = 100 + + # Benchmark only CUDA graphs (more reliable and faster) + # Benchmark Triton MoE with CUDA graphs + triton_graph_time = bench_cuda_graph( + triton_graph, num_warmup=num_warmup, num_iters=num_iters + ) + + # Benchmark CUTLASS MoE with CUDA graphs + cutlass_graph_time = bench_cuda_graph( + cutlass_graph, num_warmup=num_warmup, num_iters=num_iters + ) + + # Convert ms to us and return results + triton_time_us = triton_graph_time * 1000 + cutlass_time_us = cutlass_graph_time * 1000 + + return { + "batch_size": m, + "triton_time_us": triton_time_us, + "cutlass_time_us": cutlass_time_us, + } + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + all_results = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in args.per_act_token_opts: + for per_out_ch in args.per_out_ch_opts: + print( + f"\n=== {model}, experts={num_experts}, topk={topk}," + f"per_act={per_act_token}, per_out_ch={per_out_ch} ===" + ) + + config_results 
= [] + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + result = bench_run( + [], # Not used anymore + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + if result: + config_results.append(result) + + # Print results table for this configuration + if config_results: + print( + f"\n{'Batch Size':<12}" + f"{'Triton (us)':<15}" + f"{'CUTLASS (us)':<15}" + ) + print("-" * 45) + for result in config_results: + print( + f"{result['batch_size']:<12}" + f"{result['triton_time_us']:<15.2f}" + f"{result['cutlass_time_us']:<15.2f}" + ) + + all_results.extend(config_results) + + print(f"\nTotal benchmarks completed: {len(all_results)}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE + across specified models/shapes/batches + + Example usage: + python benchmark_cutlass_moe_fp8.py \ + --model "Llama-4-Maverick-17B-128E-Instruct-FP8" \ + --tp-sizes 8 \ + --batch-size 2 4 8 \ + --per-act-token-opts false \ + --per-out-ch-opts false + + """ + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument( + "--per-act-token-opts", + nargs="+", + type=lambda x: x.lower() == "true", + default=[False, True], + help="Per-activation token quantization options (true/false)", + ) + parser.add_argument( + "--per-out-ch-opts", + nargs="+", + type=lambda x: x.lower() == "true", + default=[False, True], + help="Per-output channel quantization options (true/false)", + ) + + args = parser.parse_args() + main(args) diff --git a/cmake/cpu_extension.cmake 
b/cmake/cpu_extension.cmake index 06494463223bd..2a2ec08f86951 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -258,7 +258,8 @@ set(VLLM_EXT_SRC "csrc/cpu/layernorm.cpp" "csrc/cpu/mla_decode.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp") + "csrc/cpu/torch_bindings.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index fbbc2e588c326..297d94dcc0631 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -135,10 +135,10 @@ public: max_splits = min(16, max_splits); // TODO: This avoids a hang when the batch size larger than 1 and - // there is more than 4 kv_splits. + // there is more than 1 kv_splits. // Discuss with NVIDIA how this can be fixed. if (B > 1) { - max_splits = min(2, max_splits); + max_splits = min(1, max_splits); } // printf(" max_splits = %d\n", max_splits); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 98c3ebc5a75f8..d279c03e0b59a 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + ops.def( + "dynamic_4bit_int_moe(" + "Tensor x, Tensor topk_ids, Tensor topk_weights," + "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2," + "int group_size, bool apply_router_weight_on_input, int activation_kind" + ") -> Tensor"); + + ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu); + // PagedAttention V2. 
ops.def( "paged_attention_v2(" diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp new file mode 100644 index 0000000000000..1d06fc6b5b0a0 --- /dev/null +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -0,0 +1,156 @@ +#include +#include +#include + +// _dyn_quant_matmul_4bit is only available on AArch64. +#if defined(__aarch64__) + #include +#endif + +inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w, + int64_t group_size_eff, int64_t in_features, + int64_t out_features) { +#if defined(__aarch64__) + return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff, + in_features, out_features); +#else + TORCH_CHECK(false, + "dynamic 4-bit int MoE path requires AArch64 (ARM64); " + "_dyn_quant_matmul_4bit is unavailable on this architecture"); + return {}; +#endif +} + +enum ActivationKind : int64_t { + SwiGLU_Gu = 0, // act = SiLU(g) * u + SwiGLUOAI = 1, // act = SiLU(u) * g + SiLU = 2 // SiLU +}; + +torch::Tensor dynamic_4bit_int_moe_cpu( + torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, + torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I, + int64_t I2, int64_t group_size, bool apply_router_weight_on_input, + int64_t activation_kind) { + TORCH_CHECK(x.dim() == 2, "x must be 2D"); + TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2, + "topk tensors must be [T, K]"); + TORCH_CHECK( + w13_packed.size(0) == w2_packed.size(0), + "w13_packed and w2_packed must have same number of experts in dim 0"); + TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I"); + + const int64_t T = x.size(0); + const int64_t K = topk_ids.size(1); + const int64_t E = w13_packed.size(0); + const int64_t N = T * K; + + auto x_c = x.contiguous(); + auto ids_c = topk_ids.contiguous(); + auto gates_c = topk_weights.to(at::kFloat).contiguous(); + + // bucketing tokens -> experts + c10::SmallVector counts( + E, 0); // Small vector uses stack allocation + { + const auto* ids_ptr = 
ids_c.data_ptr(); + for (int64_t i = 0; i < N; ++i) { + const int64_t e_id = ids_ptr[i]; + TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range"); + counts[e_id]++; + } + } + c10::SmallVector offsets(E + 1, 0); // ( E +1 ) + for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e]; + + auto expert_tokens = at::empty({offsets[E]}, ids_c.options()); + auto expert_gates = at::empty({offsets[E]}, gates_c.options()); + { + c10::SmallVector cursor(E, 0); + const auto* ids_ptr = ids_c.data_ptr(); + const auto* gts_ptr = gates_c.data_ptr(); + auto* tok_ptr = expert_tokens.data_ptr(); + auto* gate_ptr = expert_gates.data_ptr(); + + for (int64_t t = 0; t < T; ++t) { + const int64_t base = t * K; + for (int64_t k = 0; k < K; ++k) { + const int64_t idx = base + k; + const int64_t e = ids_ptr[idx]; + const int64_t p = offsets[e] + (cursor[e]++); + tok_ptr[p] = t; + gate_ptr[p] = gts_ptr[idx]; + } + } + } + + const int64_t g_eff_13 = (group_size != -1) ? group_size : H; + const int64_t g_eff_2 = (group_size != -1) ? 
group_size : I; + + // Per-expert outputs filled in parallel + std::vector y_list(E); + y_list.resize(E); + + at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + for (int64_t e = e_begin; e < e_end; ++e) { + const int64_t te = counts[e]; + if (te == 0) { + y_list[e] = at::empty({0, H}, x_c.options()); + continue; + } + + const int64_t start = offsets[e]; + + auto sel_tokens = + expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); + auto gates_e = + expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); + + auto x_e = x_c.index_select(/*dim=*/0, sel_tokens); + + if (apply_router_weight_on_input) { + x_e = x_e.mul(gates_e.unsqueeze(1)); + } + + auto w13_e = w13_packed.select(/*dim=*/0, e); + auto w2_e = w2_packed.select(/*dim=*/0, e); + + // W13 + auto y13 = + mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2); + + auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I); + auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I); + + torch::Tensor act; + if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI + constexpr double kAlpha = 1.702; // GPT-OSS default + constexpr double kLimit = 7.0; // GPT-OSS default + auto gate_c = at::clamp_max(g_part, kLimit); + auto up_c = at::clamp(u_part, -kLimit, kLimit); + auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha))); + act = up_c.add(1.0).mul(glu); + } else { // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul() + act = at::silu(g_part).mul(u_part); + } + + // W2 + auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H); + + if (!apply_router_weight_on_input) { + y = y.mul(gates_e.unsqueeze(1)); + } + + // Store per-expert result + y_list[e] = y; + } + }); + + // Concatenate all expert outputs to match expert_tokens order + auto Y_all = at::cat(y_list, /*dim=*/0); + auto out = at::zeros({T, H}, x.options()); + out = + at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all); + + return out; +} diff --git a/csrc/ops.h 
b/csrc/ops.h index fd9c55b948959..2ada7905da4b5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const std::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); +torch::Tensor dynamic_4bit_int_moe_cpu( + torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, + torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I, + int64_t I2, int64_t group_size, bool apply_router_weight_on_input, + int64_t activation_kind); + using fptr_t = int64_t; fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 9aa1411b4a25c..b94cc9ce5086c 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -23,9 +23,14 @@ typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16_raw __nv_bfloat16_raw; - + #if defined(HIP_FP8_TYPE_OCP) typedef __hip_fp8_e4m3 __nv_fp8_e4m3; typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; + #else +// ROCm 6.2 fallback: only *_fnuz types exist +typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3; + #endif #endif #include "core/registration.h" diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index dac9df6048f2a..133a545045b12 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -25,6 +25,12 @@ #include "../attention/dtype_fp8.cuh" #include "../quantization/fp8/amd/quant_utils.cuh" +// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent +#if !defined(HIP_FP8_TYPE_OCP) +using __hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz; +using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz; +#endif + #if defined(__HIPCC__) && \ (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) #define __HIP__GFX9__ diff --git 
a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md index 2c69304db3393..fe065b52268a6 100644 --- a/docs/features/disagg_prefill.md +++ b/docs/features/disagg_prefill.md @@ -34,7 +34,7 @@ Now supports 5 types of connectors: For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as: ```bash - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}' + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}' ``` - **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker): diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index f823d33df80ea..e44a914c726db 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. -2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. 
Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'` +2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'` 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions. diff --git a/requirements/common.txt b/requirements/common.txt index 7973da080c37d..a52745f698703 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -24,7 +24,7 @@ outlines_core == 0.2.11 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 -xgrammar == 0.1.24; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" +xgrammar == 0.1.25; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/tests/conftest.py b/tests/conftest.py index dc70c98359598..a50985a465e6c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1079,7 +1079,7 @@ def dummy_llava_path(): local_dir=_dummy_llava_path, ignore_patterns=[ "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" + "*.msgpack", "*.safetensors" ]) assert os.path.exists(json_path) with open(json_path) as f: @@ -1098,7 +1098,7 @@ def dummy_gemma2_embedding_path(): local_dir=_dummy_gemma2_embedding_path, ignore_patterns=[ "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" + "*.msgpack", "*.safetensors" ]) assert os.path.exists(json_path) with open(json_path) as f: diff 
--git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 073b362b64749..aa28ed9ce25e5 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -382,7 +382,6 @@ def test_tp_language_generation( test_options: PPTestOptions, num_gpus_available, ): - pytest.skip("Skipping the test until V1 passes it.") _compare_tp(model_id, parallel_setup, distributed_backend, @@ -410,7 +409,6 @@ def test_tp_language_embedding( test_options: PPTestOptions, num_gpus_available, ): - pytest.skip("Skipping the test until V1 passes it.") _compare_tp(model_id, parallel_setup, distributed_backend, @@ -438,7 +436,6 @@ def test_tp_multimodal_generation( test_options: PPTestOptions, num_gpus_available, ): - pytest.skip("Skipping the test until V1 passes it.") _compare_tp(model_id, parallel_setup, distributed_backend, diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 23d8373d97809..c28970afc731f 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -523,6 +523,7 @@ async def test_function_calling(client: OpenAI, model_name: str): input="What's the weather like in Paris today?", tools=tools, temperature=0.0, + extra_body={"request_id": "test_function_calling_non_resp"}, ) assert response is not None assert response.status == "completed" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 8e68699e5904a..b773061b3092b 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -194,6 +194,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, assert tc.function is not None and tc.function.name == "get_current_weather" args1 = tc.function.arguments assert args1 is not None and len(args1) > 0 + assert not 
first_msg.content messages.append({"role": "assistant", "content": args1}) messages.append({ diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index a4e200775c09d..730514eb5a568 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -85,8 +85,7 @@ def test_env( if device == "cpu": with patch("vllm.attention.selector.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float16, None, block_size, - False) + backend = get_attn_backend(16, torch.float16, None, block_size) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "hip": @@ -106,7 +105,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) assert f"The selected backend, {name}" in str( exc_info.value) @@ -117,7 +115,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) assert f"The selected backend, {name}" in str( exc_info.value) @@ -127,7 +124,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = f"{name}_VLLM_V1" assert backend.get_name() == expected @@ -136,7 +132,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = "TRITON_ATTN_VLLM_V1" assert backend.get_name() == expected @@ -164,7 +159,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = "CUTLASS_MLA_VLLM_V1" assert backend.get_name() == expected @@ -179,7 +173,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = "FLASHINFER_MLA" assert backend.get_name() == expected @@ -199,7 +192,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = f"{name}_VLLM_V1" assert backend.get_name() == expected @@ -208,7 +200,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected @@ -218,7 +209,6 @@ 
def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = "TRITON_MLA_VLLM_V1" assert backend.get_name() == expected @@ -227,7 +217,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = "FLASHINFER_VLLM_V1" assert backend.get_name() == expected @@ -236,7 +225,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) expected = "FLASH_ATTN_VLLM_V1" assert backend.get_name() == expected @@ -245,7 +233,6 @@ def test_env( torch.float16, None, block_size, - False, use_mla=use_mla) assert backend.get_name() == "FLEX_ATTENTION", ( "Should fallback to FlexAttention if head size is " @@ -264,13 +251,13 @@ def test_fp32_fallback( if device == "cpu": with patch("vllm.attention.selector.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float32, None, 16, False) + backend = get_attn_backend(16, torch.float32, None, 16) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "cuda": with patch("vllm.attention.selector.current_platform", CudaPlatform()): - backend = get_attn_backend(16, torch.float32, None, 16, False) + backend = get_attn_backend(16, torch.float32, None, 16) assert backend.get_name() == "FLEX_ATTENTION" @@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) - backend = get_attn_backend(16, torch.float16, None, 16, False) + backend = get_attn_backend(16, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Reset the monkeypatch for subsequent tests monkeypatch.undo() # Unsupported data type - backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False) + backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Unsupported kv cache data type - backend = get_attn_backend(16, torch.float16, "fp8", 16, False) + backend = get_attn_backend(16, torch.float16, "fp8", 16) assert 
backend.get_name() != STR_FLASH_ATTN_VAL # Unsupported block size - backend = get_attn_backend(16, torch.float16, None, 8, False) + backend = get_attn_backend(16, torch.float16, None, 8) assert backend.get_name() != STR_FLASH_ATTN_VAL # flash-attn is not installed import sys original_module = sys.modules.get('vllm_flash_attn') monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) - backend = get_attn_backend(16, torch.float16, None, 16, False) + backend = get_attn_backend(16, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Restore the original module if it existed @@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) # Unsupported head size - backend = get_attn_backend(17, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Attention-free models should bypass env and use PlaceholderAttention - backend = get_attn_backend(16, torch.float16, None, 16, True) + backend = get_attn_backend(17, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL @@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch): # Should raise ValueError for invalid backend with pytest.raises(ValueError) as exc_info: - get_attn_backend(32, torch.float16, None, 16, False) + get_attn_backend(32, torch.float16, None, 16) assert "Invalid value 'INVALID'" in str(exc_info.value) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0941cc3f608e7..4eb8e0cfaa5d0 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -12,11 +12,11 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image from vllm.config import ModelConfig -from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from 
vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, cached_tokenizer_from_config, encode_tokens) diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index b678313752d65..d5d5bfaa3b45b 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -18,10 +18,10 @@ from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) -from vllm.inputs import InputProcessingContext from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs -from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils import is_list_of diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 9b376f2a260ac..4aa7bb7297893 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -42,7 +42,6 @@ def test_oot_registration_text_generation( assert rest == "" -@pytest.mark.skip(reason="This test is skipped because it failed on V1.") @create_new_process_for_each_test() def test_oot_registration_embedding( monkeypatch: pytest.MonkeyPatch, @@ -63,7 +62,6 @@ def test_oot_registration_embedding( image = 
convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") -@pytest.mark.skip(reason="This test is skipped because it failed on V1.") @create_new_process_for_each_test() def test_oot_registration_multimodal( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/models/utils.py b/tests/models/utils.py index 5da2382cef814..f80e92ebb3e2f 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -11,8 +11,9 @@ import torch.nn.functional as F from transformers import PretrainedConfig from vllm.config import ModelConfig, ModelDType, RunnerOption -from vllm.inputs import InputContext from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs +from vllm.multimodal.processing import InputProcessingContext +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from .registry import HF_EXAMPLE_MODELS @@ -264,7 +265,7 @@ def build_model_context( limit_mm_per_prompt: Optional[dict[str, int]] = None, mm_processor_cache_gb: int = 0, ): - """Creates an InputContext for a given model. + """Creates an InputProcessingContext for a given model. Args: model_id: ID of the model being considered. @@ -273,7 +274,7 @@ def build_model_context( limit_mm_per_prompt: Multimodal limits. Returns: - InputContext for the model being considered. + InputProcessingContext for the model being considered. 
""" model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") @@ -298,7 +299,11 @@ def build_model_context( enforce_eager=model_info.enforce_eager, **model_config_kwargs, ) - return InputContext(model_config) + + return InputProcessingContext( + model_config, + tokenizer=cached_tokenizer_from_config(model_config), + ) def check_embeddings_close( diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 6ce5fcfe644bd..352b5b5b4fd46 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -8,11 +8,11 @@ import numpy as np import pytest from vllm.config import ModelConfig -from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY # yapf conflicts with isort for this block # yapf: disable -from vllm.multimodal.processing import (PlaceholderFeaturesInfo, +from vllm.multimodal.processing import (InputProcessingContext, + PlaceholderFeaturesInfo, PromptIndexTargets, PromptInsertion, PromptReplacement, apply_text_matches, apply_token_matches, diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py new file mode 100644 index 0000000000000..6a939dcfc2c9c --- /dev/null +++ b/tests/reasoning/test_base_thinking_reasoning_parser.py @@ -0,0 +1,392 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser + + +# Create a concrete test implementation of BaseThinkingReasoningParser +class TestThinkingReasoningParser(BaseThinkingReasoningParser): + """Test implementation of BaseThinkingReasoningParser.""" + + @property + def start_token(self) -> str: + return 
"" + + @property + def end_token(self) -> str: + return "" + + +class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser): + """Alternative test implementation with different tokens.""" + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + +# Use a test model +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + +@pytest.fixture(scope="module") +def test_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + # Add custom test tokens + test_tokens = ["", "", "", ""] + existing_tokens = set(tokenizer.get_vocab().keys()) + new_tokens = [ + token for token in test_tokens if token not in existing_tokens + ] + if new_tokens: + tokenizer.add_tokens(new_tokens) + return tokenizer + + +class TestBaseThinkingReasoningParserInit: + """ + Test initialization and basic properties of + BaseThinkingReasoningParser. + """ + + def test_successful_initialization(self, test_tokenizer): + """Test successful initialization with valid tokens.""" + parser = TestThinkingReasoningParser(test_tokenizer) + assert parser.start_token == "" + assert parser.end_token == "" + assert parser.start_token_id is not None + assert parser.end_token_id is not None + + def test_initialization_with_missing_tokenizer(self): + """Test that initialization fails without tokenizer.""" + with pytest.raises(ValueError, match="model tokenizer must be passed"): + TestThinkingReasoningParser(None) + + def test_initialization_with_missing_tokens(self, test_tokenizer): + """Test that initialization fails when tokens are not in vocabulary.""" + + # Create a parser with tokens not in vocabulary + class MissingTokenParser(BaseThinkingReasoningParser): + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + with pytest.raises(RuntimeError, + match="could not locate think start/end tokens"): + MissingTokenParser(test_tokenizer) + + def 
test_initialization_with_empty_tokens(self, test_tokenizer): + """Test that initialization fails with empty token strings.""" + + class EmptyTokenParser(BaseThinkingReasoningParser): + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + with pytest.raises(ValueError, + match="start_token and end_token must be defined"): + EmptyTokenParser(test_tokenizer) + + +class TestBaseThinkingReasoningParserMethods: + """Test the methods of BaseThinkingReasoningParser.""" + + def test_is_reasoning_end(self, test_tokenizer): + """Test the is_reasoning_end method.""" + parser = TestThinkingReasoningParser(test_tokenizer) + end_token_id = parser.end_token_id + + # Test with end token present + assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True + + # Test without end token + assert parser.is_reasoning_end([1, 2, 3, 4]) is False + + # Test with empty list + assert parser.is_reasoning_end([]) is False + + def test_extract_content_ids(self, test_tokenizer): + """Test the extract_content_ids method.""" + parser = TestThinkingReasoningParser(test_tokenizer) + end_token_id = parser.end_token_id + + # Test with end token in the middle + input_ids = [1, 2, end_token_id, 4, 5] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [4, 5] + + # Test with end token at the end + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + # Test without end token + input_ids = [1, 2, 3, 4] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + # Test with end token as last element (should not extract) + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + +class TestBaseThinkingReasoningParserExtraction: + """Test reasoning content extraction methods.""" + + def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer): + 
"""Test extraction when both start and end tokens are present.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = ("This is reasoning" + "This is content") + reasoning, content = parser.extract_reasoning_content( + model_output, request) + + assert reasoning == "This is reasoning" + assert content == "This is content" + + def test_extract_reasoning_content_only_end_token(self, test_tokenizer): + """Test extraction when only end token is present.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = ("This is reasoningThis is content") + reasoning, content = parser.extract_reasoning_content( + model_output, request) + + assert reasoning == "This is reasoning" + assert content == "This is content" + + def test_extract_reasoning_content_no_end_token(self, test_tokenizer): + """Test extraction when no end token is present.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = "This is just content" + reasoning, content = parser.extract_reasoning_content( + model_output, request) + + assert reasoning == "This is just content" + assert content is None + + def test_extract_reasoning_content_empty_output(self, test_tokenizer): + """Test extraction with empty output.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = "" + reasoning, content = parser.extract_reasoning_content( + model_output, request) + + assert reasoning == "" + assert content is None + + def test_extract_reasoning_content_only_tokens(self, test_tokenizer): + """Test extraction with only tokens and no content.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + 
model_output = ("") + reasoning, content = parser.extract_reasoning_content( + model_output, request) + + assert reasoning == "" + assert content is None + + +class TestBaseThinkingReasoningParserStreaming: + """Test streaming functionality of BaseThinkingReasoningParser.""" + + @pytest.mark.parametrize("streaming", [True, False]) + def test_simple_reasoning_extraction(self, test_tokenizer, streaming): + """ + Test basic reasoning extraction in both + streaming and non-streaming modes. + """ + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = [ + "", "Some ", "reasoning ", "content", "", + "Final ", "answer" + ] + + reasoning, content = run_reasoning_extraction(parser, + model_output, + streaming=streaming) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" + + def test_streaming_with_incremental_deltas(self, test_tokenizer): + """Test streaming processing with small incremental deltas.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Some ", + "reasoning ", + "content", + "", + "Final ", + "answer", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" + + def test_streaming_with_start_token(self, test_tokenizer): + """Test streaming with start token included.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Some ", + "reasoning", + "", + "Answer", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Some reasoning" + assert content == "Answer" + + def test_streaming_no_end_token(self, test_tokenizer): + """Test streaming when no end token is encountered.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Some ", + "reasoning ", + "without ", + "end", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + 
streaming=True) + + assert reasoning == "Some reasoning without end" + assert content is None + + def test_streaming_only_end_token(self, test_tokenizer): + """Test streaming when only end token appears.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "", + "Reasoning ", + "content", + "", + "Final", + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Reasoning content" + assert content == "Final" + + +class TestBaseThinkingReasoningParserMultipleImplementations: + """ + Test that multiple implementations of + BaseThinkingReasoningParser work correctly. + """ + + def test_different_token_implementations(self, test_tokenizer): + """ + Test that different implementations + with different tokens work independently. + """ + parser1 = TestThinkingReasoningParser(test_tokenizer) + parser2 = TestThinkingReasoningParserAlt(test_tokenizer) + + # Test parser1 + model_output1 = ("Reasoning1Content1") + reasoning1, content1 = run_reasoning_extraction( + parser1, [model_output1]) + assert reasoning1 == "Reasoning1" + assert content1 == "Content1" + + # Test parser2 + model_output2 = "Reasoning2Content2" + reasoning2, content2 = run_reasoning_extraction( + parser2, [model_output2]) + assert reasoning2 == "Reasoning2" + assert content2 == "Content2" + + # Verify tokens are different + assert parser1.start_token != parser2.start_token + assert parser1.end_token != parser2.end_token + assert parser1.start_token_id != parser2.start_token_id + assert parser1.end_token_id != parser2.end_token_id + + +class TestBaseThinkingReasoningParserEdgeCases: + """Test edge cases and error conditions.""" + + def test_multiple_end_tokens(self, test_tokenizer): + """Test behavior with multiple end tokens.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = ("FirstMiddleLast") + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should stop at first end token + 
assert reasoning == "First" + assert content == "MiddleLast" + + def test_nested_tokens(self, test_tokenizer): + """Test behavior with nested-like token patterns.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = ("Outer" + "InnerContent") + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should process normally, start from first start token + assert reasoning == "OuterInner" + assert content == "Content" + + def test_malformed_tokens(self, test_tokenizer): + """Test behavior with malformed token-like strings.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = ("Not a real token" + "Content") + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should treat as regular content since tokens don't match exactly + assert reasoning == ("Not a real token" + "Content") + assert content is None diff --git a/tests/reasoning/test_seedoss_reasoning_parser.py b/tests/reasoning/test_seedoss_reasoning_parser.py new file mode 100644 index 0000000000000..bb5dc0f4ffe4d --- /dev/null +++ b/tests/reasoning/test_seedoss_reasoning_parser.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, cast + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "seed_oss" +start_token = "" +end_token = "" + +# Use a test model that contains our custom tokens +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + +@pytest.fixture(scope="module") +def seedoss_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + # Add custom SeedOSS tokens if they don't exist + if start_token not in tokenizer.get_vocab(): + tokenizer.add_tokens([start_token, end_token]) + return tokenizer + + +SIMPLE_REASONING: dict[str, 
Any] = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING: dict[str, Any] = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +NO_CONTENT: dict[str, Any] = { + "output": "This is content", + "reasoning_content": "This is content", + "content": None, + "is_reasoning_end": False, +} +NO_REASONING_STREAMING: dict[str, Any] = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +MULTIPLE_LINES: dict[str, Any] = { + "output": "This\nThatThis is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +WITH_START_TOKEN: dict[str, Any] = { + "output": ("This is a reasoning section" + "This is the rest"), + "reasoning_content": + "This is a reasoning section", + "content": + "This is the rest", + "is_reasoning_end": + True, +} +ONLY_END_TOKEN: dict[str, Any] = { + "output": "Some reasoningThis is the rest", + "reasoning_content": "Some reasoning", + "content": "This is the rest", + "is_reasoning_end": True, +} +NO_TOKENS: dict[str, Any] = { + "output": "This is just content without any reasoning tokens", + "reasoning_content": "This is just content without any reasoning tokens", + "content": None, + "is_reasoning_end": False, +} + + +def test_seedoss_reasoning_parser_creation(seedoss_tokenizer): + """Test that the SeedOSS reasoning parser can be created and registered.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + assert isinstance(parser, ReasoningParser) + assert parser.start_token == start_token + assert parser.end_token == end_token + + +@pytest.mark.parametrize("streaming", [True, False]) +def 
test_simple_reasoning(seedoss_tokenizer, streaming): + """Test basic reasoning extraction with both tokens.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming) + + assert reasoning == SIMPLE_REASONING["reasoning_content"] + assert content == SIMPLE_REASONING["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_complete_reasoning(seedoss_tokenizer, streaming): + """Test reasoning extraction when there's no content after reasoning.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, COMPLETE_REASONING["output"])], streaming=streaming) + + assert reasoning == COMPLETE_REASONING["reasoning_content"] + assert content == COMPLETE_REASONING["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_content(seedoss_tokenizer, streaming): + """Test when there's no end token - everything is reasoning content.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, NO_CONTENT["output"])], streaming=streaming) + + assert reasoning == NO_CONTENT["reasoning_content"] + assert content == NO_CONTENT["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_multiple_lines(seedoss_tokenizer, streaming): + """Test reasoning extraction with multiline content.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming) + + assert reasoning == MULTIPLE_LINES["reasoning_content"] + assert content == 
MULTIPLE_LINES["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_with_start_token(seedoss_tokenizer, streaming): + """Test reasoning extraction with both start and end tokens.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming) + + assert reasoning == WITH_START_TOKEN["reasoning_content"] + assert content == WITH_START_TOKEN["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_only_end_token(seedoss_tokenizer, streaming): + """ + Test reasoning extraction with only end token + (SeedOSS typical behavior). + """ + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming) + + assert reasoning == ONLY_END_TOKEN["reasoning_content"] + assert content == ONLY_END_TOKEN["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tokens(seedoss_tokenizer, streaming): + """Test when there are no reasoning tokens at all.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, NO_TOKENS["output"])], streaming=streaming) + + assert reasoning == NO_TOKENS["reasoning_content"] + assert content == NO_TOKENS["content"] + + +def test_is_reasoning_end(seedoss_tokenizer): + """Test the is_reasoning_end method.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + # Test with end token present + end_token_id = parser.end_token_id + assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True + + # Test without end token + assert parser.is_reasoning_end([1, 2, 
3, 4]) is False + + +def test_extract_content_ids(seedoss_tokenizer): + """Test the extract_content_ids method.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + end_token_id = parser.end_token_id + + # Test with end token in the middle + input_ids = [1, 2, end_token_id, 4, 5] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [4, 5] + + # Test with end token at the end + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + # Test without end token + input_ids = [1, 2, 3, 4] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + +def test_streaming_delta_processing(seedoss_tokenizer): + """Test streaming processing with small deltas.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + # Test streaming with incremental tokens + deltas = [ + "Some ", "reasoning ", "content", "", "Final ", "answer" + ] + + reasoning, content = run_reasoning_extraction(parser, + deltas, + streaming=True) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" diff --git a/tests/test_config.py b/tests/test_config.py index 0796447c079b6..9e2bfb9e1b0ec 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import MISSING, Field, asdict, dataclass, field +from unittest.mock import patch import pytest @@ -388,3 +390,108 @@ def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len, else: actual_max_len = model_config.get_and_verify_max_len(max_model_len) assert actual_max_len == expected_max_len + + +class MockConfig: + """Simple mock object for testing maybe_pull_model_tokenizer_for_runai""" + + def __init__(self, model: str, 
tokenizer: str): + self.model = model + self.tokenizer = tokenizer + self.model_weights = None + + +@pytest.mark.parametrize("s3_url", [ + "s3://example-bucket-1/model/", + "s3://example-bucket-2/model/", +]) +@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files') +def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url): + """Test that S3 URLs create deterministic local directories for model and + tokenizer.""" + # Mock pull_files to avoid actually downloading files during tests + mock_pull_files.return_value = None + + # Create first mock and run the method + config1 = MockConfig(model=s3_url, tokenizer=s3_url) + ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url) + + # Check that model and tokenizer point to existing directories + assert os.path.exists( + config1.model), f"Model directory does not exist: {config1.model}" + assert os.path.isdir( + config1.model), f"Model path is not a directory: {config1.model}" + assert os.path.exists( + config1.tokenizer + ), f"Tokenizer directory does not exist: {config1.tokenizer}" + assert os.path.isdir( + config1.tokenizer + ), f"Tokenizer path is not a directory: {config1.tokenizer}" + + # Verify that the paths are different from the original S3 URL + assert config1.model != s3_url, ( + "Model path should be converted to local directory") + assert config1.tokenizer != s3_url, ( + "Tokenizer path should be converted to local directory") + + # Store the original paths + created_model_dir = config1.model + create_tokenizer_dir = config1.tokenizer + + # Create a new mock and run the method with the same S3 URL + config2 = MockConfig(model=s3_url, tokenizer=s3_url) + ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url) + + # Check that the new directories exist + assert os.path.exists( + config2.model), f"Model directory does not exist: {config2.model}" + assert os.path.isdir( + config2.model), f"Model path is not a directory: {config2.model}" + assert 
os.path.exists( + config2.tokenizer + ), f"Tokenizer directory does not exist: {config2.tokenizer}" + assert os.path.isdir( + config2.tokenizer + ), f"Tokenizer path is not a directory: {config2.tokenizer}" + + # Verify that the paths are deterministic (same as before) + assert config2.model == created_model_dir, ( + f"Model paths are not deterministic. " + f"Original: {created_model_dir}, New: {config2.model}") + assert config2.tokenizer == create_tokenizer_dir, ( + f"Tokenizer paths are not deterministic. " + f"Original: {create_tokenizer_dir}, New: {config2.tokenizer}") + + +@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files') +def test_s3_url_different_models_create_different_directories(mock_pull_files): + """Test that different S3 URLs create different local directories.""" + # Mock pull_files to avoid actually downloading files during tests + mock_pull_files.return_value = None + + s3_url1 = "s3://example-bucket-1/model/" + s3_url2 = "s3://example-bucket-2/model/" + + # Create mocks with different S3 URLs and run the method + config1 = MockConfig(model=s3_url1, tokenizer=s3_url1) + ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url1, s3_url1) + + config2 = MockConfig(model=s3_url2, tokenizer=s3_url2) + ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url2, s3_url2) + + # Verify that different URLs produce different directories + assert config1.model != config2.model, ( + f"Different S3 URLs should create different model directories. " + f"URL1 model: {config1.model}, URL2 model: {config2.model}") + assert config1.tokenizer != config2.tokenizer, ( + f"Different S3 URLs should create different tokenizer directories. 
" + f"URL1 tokenizer: {config1.tokenizer}, " + f"URL2 tokenizer: {config2.tokenizer}") + + # Verify that both sets of directories exist + assert os.path.exists(config1.model) and os.path.isdir(config1.model) + assert os.path.exists(config1.tokenizer) and os.path.isdir( + config1.tokenizer) + assert os.path.exists(config2.model) and os.path.isdir(config2.model) + assert os.path.exists(config2.tokenizer) and os.path.isdir( + config2.tokenizer) diff --git a/tests/tool_use/test_deepseekv31_tool_parser.py b/tests/tool_use/test_deepseekv31_tool_parser.py new file mode 100644 index 0000000000000..5f6b266d3aa19 --- /dev/null +++ b/tests/tool_use/test_deepseekv31_tool_parser.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.entrypoints.openai.tool_parsers import DeepSeekV31ToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +MODEL = "deepseek-ai/DeepSeek-V3.1" + + +@pytest.fixture(scope="module") +def deepseekv31_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def parser(deepseekv31_tokenizer): + return DeepSeekV31ToolParser(deepseekv31_tokenizer) + + +def test_extract_tool_calls_with_tool(parser): + model_output = ( + "normal text" + "<|tool▁calls▁begin|>" + + "<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" + + "<|tool▁calls▁end|>") + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "foo" + assert result.tool_calls[0].function.arguments == "{\"x\":1}" + assert result.content == "normal text" + + +def test_extract_tool_calls_with_multiple_tools(parser): + model_output = ( + "some prefix text" + "<|tool▁calls▁begin|>" + + "<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" + + "<|tool▁call▁begin|>bar<|tool▁sep|>{\"y\":2}<|tool▁call▁end|>" + + "<|tool▁calls▁end|>" + " some 
suffix text") + + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called + assert len(result.tool_calls) == 2 + + assert result.tool_calls[0].function.name == "foo" + assert result.tool_calls[0].function.arguments == "{\"x\":1}" + + assert result.tool_calls[1].function.name == "bar" + assert result.tool_calls[1].function.arguments == "{\"y\":2}" + + # prefix is content + assert result.content == "some prefix text" diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_use/test_openai_tool_parser.py index 0192c7d2765cd..2551c41c62754 100644 --- a/tests/tool_use/test_openai_tool_parser.py +++ b/tests/tool_use/test_openai_tool_parser.py @@ -70,7 +70,12 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding): assert extracted_info.content == "This is a test" -def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding): +@pytest.mark.parametrize("tool_args", [ + '{"location": "Tokyo"}', + '{\n"location": "Tokyo"\n}', +]) +def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding, + tool_args): convo = Conversation.from_messages([ Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"), @@ -80,7 +85,7 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding): ).with_channel("analysis"), Message.from_role_and_content( Role.ASSISTANT, - '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + tool_args).with_channel("commentary").with_recipient( "functions.get_current_weather").with_content_type("json"), ]) token_ids = harmony_encoding.render_conversation_for_completion( @@ -121,6 +126,17 @@ def test_extract_tool_calls_multiple_tools( Role.ASSISTANT, '{"location": "Tokyo"}').with_channel("commentary").with_recipient( "functions.get_user_location").with_content_type("json"), + Message.from_role_and_content( + Role.ASSISTANT, '{"location": "Tokyo"}').with_channel( + 
"commentary").with_recipient("functions.no_content_type"), + Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel( + "commentary").with_recipient("functions.not_json_no_content_type"), + Message.from_role_and_content( + Role.ASSISTANT, '{}').with_channel("commentary").with_recipient( + "functions.empty_args").with_content_type("json"), + Message.from_role_and_content( + Role.ASSISTANT, '').with_channel("commentary").with_recipient( + "functions.no_args").with_content_type("json"), ]) token_ids = harmony_encoding.render_conversation_for_completion( convo, @@ -141,7 +157,63 @@ def test_extract_tool_calls_multiple_tools( ToolCall(function=FunctionCall( name="get_user_location", arguments=json.dumps({"location": "Tokyo"}), + )), + ToolCall(function=FunctionCall( + name="no_content_type", + arguments=json.dumps({"location": "Tokyo"}), + )), + ToolCall(function=FunctionCall( + name="not_json_no_content_type", + arguments="foo", + )), + ToolCall(function=FunctionCall( + name="empty_args", + arguments=json.dumps({}), + )), + ToolCall(function=FunctionCall( + name="no_args", + arguments="", )) ] assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) assert extracted_info.content is None + + +def test_extract_tool_calls_with_content( + openai_tool_parser, + harmony_encoding, +): + final_content = "This tool call will get the weather." + convo = Conversation.from_messages([ + Message.from_role_and_content( + Role.USER, "What is the weather in Tokyo based on where I'm at?"), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" based on their location. 
We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + "functions.get_current_weather").with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, + final_content).with_channel("final"), + ]) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, + Role.ASSISTANT, + ) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + )), + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content == final_content diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py index 636108e985816..5196a92cb727f 100644 --- a/tests/tpu/lora/test_lora.py +++ b/tests/tpu/lora/test_lora.py @@ -31,7 +31,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch): def setup_vllm(num_loras: int, tp: int) -> vllm.LLM: return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=256, - max_seq_len_to_capture=256, max_num_seqs=8, tensor_parallel_size=tp, enable_lora=True, diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 7d7a46910be89..d81f3da7e9cd9 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -9,8 +9,9 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm.v1.attention.backends.utils import (UBatchSlice, _make_metadata_with_slice, slice_query_start_locs, - split_attn_metadata) -from vllm.v1.worker.ubatch_utils import create_ubatch_slices + split_attn_metadata, + split_decodes_and_prefills) +from vllm.v1.worker.ubatch_splitting import 
create_ubatch_slices @pytest.fixture @@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) +def apply_split_decodes_and_prefills(query_lens: list[int], + decode_threshold: int, + require_uniform: bool): + """Helper function to apply split_decodes_and_prefills and return + the results.""" + device = torch.device("cpu") + seq_lens = [10 * (i + 1) for i in range(len(query_lens))] + common_metadata = create_common_attn_metadata(BatchSpec( + seq_lens=seq_lens, query_lens=query_lens), + block_size=16, + device=device) + return split_decodes_and_prefills(common_metadata, + decode_threshold=decode_threshold, + require_uniform=require_uniform) + + +def test_split_decodes_and_prefills_nonuniform_all_ones(): + query_lens = [1, 1, 1] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 1, False)) + assert num_decodes == 3 + assert num_prefills == 0 + assert num_decode_tokens == 3 + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_nonuniform_all_short_decodes(): + query_lens = [1, 2, 1, 3, 2, 1, 2] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, False)) + assert num_decodes == 7 + assert num_prefills == 0 + assert num_decode_tokens == sum(query_lens) + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_nonuniform_all_prefills(): + query_lens = [4, 5, 6, 7] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, False)) + assert num_decodes == 0 + assert num_prefills == 4 + assert num_decode_tokens == 0 + assert num_prefill_tokens == sum(query_lens) + + +def test_split_decodes_and_prefills_nonuniform_mixed_batch(): + query_lens = [2, 1, 3, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + 
apply_split_decodes_and_prefills(query_lens, 4, False)) + assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4 + assert num_prefills == 4 # 5, 6, 7, 8 are all > 4 + assert num_decode_tokens == 10 # 2 + 1 + 3 + 4 + assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8 + + +def test_split_decodes_and_prefills_uniform_all_ones(): + query_lens = [1, 1, 1] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 1, True)) + assert num_decodes == 3 + assert num_prefills == 0 + assert num_decode_tokens == 3 + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_uniform_all_short_decodes(): + query_lens = [2, 2, 1, 3, 2, 1, 2] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, True)) + assert num_decodes == 2 + assert num_prefills == 5 + assert num_decode_tokens == 4 + assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2) + + +def test_split_decodes_and_prefills_uniform_all_prefills(): + query_lens = [4, 5, 6, 7] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, True)) + assert num_decodes == 0 + assert num_prefills == 4 + assert num_decode_tokens == 0 + assert num_prefill_tokens == sum(query_lens) + + +def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes(): + query_lens = [2, 2, 2, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 4, True)) + assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform + assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4 + assert num_decode_tokens == 6 # 2 + 2 + 2 + assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8 + + +def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes(): + query_lens = [2, 1, 2, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + 
apply_split_decodes_and_prefills(query_lens, 4, True)) + assert num_decodes == 1 # only the first 2 is taken as decode + assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform + assert num_decode_tokens == 2 # only the first 2 + assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens + + @pytest.mark.parametrize( "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs", [ diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3cf9d93696767..37b4f9a08e40d 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams from vllm.utils import sha256, sha256_cbor -from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - get_block_hash, get_group_id, +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, + KVCacheBlock, get_block_hash, + get_group_id, get_request_block_hasher, hash_block_tokens, init_none_hash, make_block_hash_with_group_id) @@ -138,7 +139,7 @@ def test_prefill(hash_fn): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) # Check full block metadata parent_block_hash = None @@ -171,7 +172,7 @@ def test_prefill(hash_fn): blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([5], ) + assert blocks is not None and blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt 
== 2 @@ -207,7 +208,7 @@ def test_prefill(hash_fn): blocks = manager.allocate_slots(req2, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([6], ) + assert blocks is not None and blocks.get_block_ids() == ([6], ) # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -227,7 +228,9 @@ def test_prefill(hash_fn): len(computed_blocks.blocks[0]) * 16, computed_blocks) # This block ID order also checks the eviction order. - assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], ) + assert blocks is not None and blocks.get_block_ids() == ([ + 7, 8, 9, 10, 4, 5, 6, 3, 2, 1 + ], ) assert free_block_queue.num_free_blocks == 0 assert (free_block_queue.fake_free_list_head.next_free_block @@ -261,8 +264,9 @@ def test_prefill_hybrid_model(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, - 8], [9, 10, 11, 12]) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [ + 5, 6, 7, 8 + ], [9, 10, 11, 12]) # Check full block metadata parent_block_hash = None @@ -298,7 +302,7 @@ def test_prefill_hybrid_model(): blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([13], [14], [15]) + assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15]) for block_per_group in computed_blocks.blocks: for block in block_per_group: if block != manager.block_pool.null_block: @@ -309,14 +313,15 @@ def test_prefill_hybrid_model(): manager.free(req1) cached_block_hash_to_block_bak = copy.copy( - manager.block_pool.cached_block_hash_to_block) + manager.block_pool.cached_block_hash_to_block._cache) - def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes], + def test_partial_request_hit(request_id: str, + hash_to_evict: 
list[BlockHashWithGroupId], expect_hit_length: int): req = make_request(request_id, common_token_ids + unique_token_ids, block_size, sha256) for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block.pop( + manager.block_pool.cached_block_hash_to_block._cache.pop( hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert len(req.block_hashes) == 3 @@ -324,7 +329,7 @@ def test_prefill_hybrid_model(): for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block[ + manager.block_pool.cached_block_hash_to_block._cache[ hash_with_group_id] = cached_block_hash_to_block_bak[ hash_with_group_id] manager.free(req) @@ -362,7 +367,8 @@ def test_prefill_hybrid_model(): # total cache miss. # The cache hit length of full attention is 1 * block_size. # The cache hit length of sliding window is 2 * block_size. - # Then it is cache miss as the two type of layers have different hit length. + # Then it is cache miss as the two type of layers + # have different hit length. 
test_partial_request_hit("8", [ make_block_hash_with_group_id(block_hashes[2], 0), make_block_hash_with_group_id(block_hashes[0], 1), @@ -406,7 +412,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata @@ -441,7 +447,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([5], ) + assert blocks is not None and blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -478,6 +484,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req2, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes @@ -513,7 +520,7 @@ def test_decode(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -558,7 +565,8 @@ def test_evict(): blocks = manager.allocate_slots(req0, 5 * 16 + 7, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial + # 5 full + 1 partial + assert blocks is not None and len(blocks.blocks[0]) == 6 # 3 blocks. 
req1 = make_request("1", list(range(last_token_id, @@ -570,7 +578,7 @@ def test_evict(): blocks = manager.allocate_slots(req1, 3 * 16, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 3 # 3 full blocks + assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 # 10 - (6 + 3) == 1 @@ -592,7 +600,7 @@ def test_evict(): blocks = manager.allocate_slots(req2, 3, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([10], ) + assert blocks is not None and blocks.get_block_ids() == ([10], ) assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -617,7 +625,7 @@ def test_hash_block_correct_reuse(): blocks = manager.allocate_slots(req, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 # Deallocate the block. manager.free(req) @@ -631,7 +639,7 @@ def test_hash_block_correct_reuse(): blocks = manager.allocate_slots(req, num_tokens - 1, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert manager.block_pool.blocks[blocks.blocks[0] [0].block_id].block_hash is None @@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted(): blocks = manager.allocate_slots(req0, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 1 # Allocate another block. @@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted(): blocks = manager.allocate_slots(req1, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 # Free the blocks. 
@@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted(): blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 1 + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 @@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled(): blocks = manager.allocate_slots(req1, 10, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 3 + assert blocks is not None and len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) @@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled(): blocks = manager.allocate_slots(req2, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks[0]) == 4 + assert blocks is not None and len(blocks.blocks[0]) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4)), block_size, sha256) @@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn): assert len(block_pool.cached_block_hash_to_block) == 2 assert all([block.block_hash is not None for block in blocks]) - # Test that blocks that don't start from the beginning are cached correctly. + # Test that blocks that don't start from the beginning are cached + # correctly. 
blocks += [KVCacheBlock(block_id=2)] block_pool.cache_full_blocks( request=req, @@ -1101,7 +1110,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -1112,7 +1121,7 @@ def test_reset_prefix_cache(): blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == ([5], ) + assert blocks is not None and blocks.get_block_ids() == ([5], ) # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() @@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block(): # Manually add all blocks to cached_blocks for block, block_hash in zip(pool.blocks, block_hashes): block.block_hash = block_hash - pool.cached_block_hash_to_block[block_hash][block.block_id] = block + pool.cached_block_hash_to_block.insert(block_hash, block) block0, block1, block2, block3 = pool.blocks - assert pool.cached_block_hash_to_block == { + assert pool.cached_block_hash_to_block._cache == { block_hash0: { block0.block_id: block0, - block3.block_id: block3 + block3.block_id: block3, }, - block_hash1: { - block1.block_id: block1 - }, - block_hash2: { - block2.block_id: block2 - } + block_hash1: block1, + block_hash2: block2, } # Evict block1 pool._maybe_evict_cached_block(block1) - assert pool.cached_block_hash_to_block == { + assert pool.cached_block_hash_to_block._cache == { block_hash0: { block0.block_id: block0, block3.block_id: block3 }, - block_hash2: { - block2.block_id: block2 - } + block_hash2: block2, } # Evict block0: block_hash0 entry should NOT be removed, as block3 # also use the same hash pool._maybe_evict_cached_block(block0) - 
assert pool.cached_block_hash_to_block == { + assert pool.cached_block_hash_to_block._cache == { block_hash0: { block3.block_id: block3 }, - block_hash2: { - block2.block_id: block2 - } + block_hash2: block2, } # Evict block2 pool._maybe_evict_cached_block(block2) - assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}} + assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}} # Evict block3 pool._maybe_evict_cached_block(block3) - assert pool.cached_block_hash_to_block == {} + assert pool.cached_block_hash_to_block._cache == {} @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) @@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window(): # Evict the first block in the request assert manager.block_pool.get_cached_block( block_hash_first_block, kv_cache_group_ids=[0]) is not None - manager.block_pool.cached_block_hash_to_block.pop( + manager.block_pool.cached_block_hash_to_block._cache.pop( make_block_hash_with_group_id(block_hash_first_block, 0)) # New request @@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window(): # there will be no matched prefix. 
assert len(computed_blocks.blocks[0]) == 0 assert num_tokens == 0 + + +def test_block_lookup_cache_single_block_per_key(): + cache = BlockHashToBlockMap() + key0 = BlockHashWithGroupId(b"hash0") + key1 = BlockHashWithGroupId(b"hash1") + key2 = BlockHashWithGroupId(b"hash2") + block0 = KVCacheBlock(0) + block1 = KVCacheBlock(1) + + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is None + assert cache.get_one_block(key2) is None + # key0 inserted + cache.insert(key0, block0) + assert cache.get_one_block(key0) is block0 + assert cache.get_one_block(key1) is None + assert cache.get_one_block(key2) is None + # key1 inserted + cache.insert(key1, block1) + assert cache.get_one_block(key0) is block0 + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # No block poped due to block_id mismatch + assert cache.pop(key0, 100) is None + assert cache.get_one_block(key0) is block0 + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # block poped with (key0, block ID 0) + assert cache.pop(key0, 0) is block0 + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # No block poped due to block_id mismatch + assert cache.pop(key0, 1) is None + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is block1 + assert cache.get_one_block(key2) is None + # block poped with (key1, block ID 1) + assert cache.pop(key1, 1) is block1 + assert cache.get_one_block(key0) is None + assert cache.get_one_block(key1) is None + assert cache.get_one_block(key2) is None + + +def test_block_lookup_cache_multi_blocks_per_key(): + cache = BlockHashToBlockMap() + key0 = BlockHashWithGroupId(b"hash0") + key1 = BlockHashWithGroupId(b"hash1") + block00 = KVCacheBlock(0) + block01 = KVCacheBlock(1) + block10 = KVCacheBlock(10) + block11 = KVCacheBlock(11) + + assert cache.get_one_block(key0) is None + assert 
cache.get_one_block(key1) is None + + cache.insert(key0, block00) + cache.insert(key0, block01) + cache.insert(key1, block10) + cache.insert(key1, block11) + + assert cache.get_one_block(key0) is block00 + assert cache.pop(key0, 0) is block00 + assert cache.get_one_block(key0) is block01 + assert cache.pop(key0, 1) is block01 + assert cache.get_one_block(key0) is None + assert cache.pop(key0, 2) is None + + assert cache.get_one_block(key1) is block10 + assert cache.pop(key1, 10) is block10 + assert cache.get_one_block(key1) is block11 + assert cache.pop(key1, 11) is block11 + assert cache.get_one_block(key1) is None + assert cache.pop(key1, 12) is None diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b70850a9bcff9..01b54ae56e90a 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -47,16 +47,15 @@ def test_chunked_local_attention_possible_cached_prefix(): BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block.clear() + block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[ - make_block_hash_with_group_id(block_hash, 0)] = { - i: block_pool.blocks[i + 10], - } + block_pool.cached_block_hash_to_block.insert( + make_block_hash_with_group_id(block_hash, 0), + block_pool.blocks[i + 10]) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -112,16 +111,15 @@ def test_sliding_window_possible_cached_prefix(): BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block.clear() + block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks for i, (block_hash, is_cached) in 
enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[ - make_block_hash_with_group_id(block_hash, 0)] = { - i: block_pool.blocks[i + 10], - } + block_pool.cached_block_hash_to_block.insert( + make_block_hash_with_group_id(block_hash, 0), + block_pool.blocks[i + 10]) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index e2c686928cea1..5017c83025ba1 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -81,16 +81,6 @@ class CarDescription(BaseModel): car_type: CarType -def _load_json(s: str, backend: str) -> str: - if backend != "xgrammar": - return json.loads(s) - - # xgrammar specific workarounds - # https://github.com/mlc-ai/xgrammar/issues/286 - s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s) - return json.loads(s) - - def test_guided_decoding_deprecated(): with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"): @@ -177,7 +167,12 @@ def test_structured_output( if backend != 'lm-format-enforcer': assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - output_json = json.loads(generated_text) + try: + output_json = json.loads(generated_text) + except json.JSONDecodeError as e: + pytest.fail( + f"Invalid JSON from backend={backend}: {generated_text!r}\n" + f"Schema: {sample_json_schema}\nError: {e}") jsonschema.validate(instance=output_json, schema=sample_json_schema) # @@ -425,7 +420,12 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - output_json = json.loads(generated_text) + try: + output_json = json.loads(generated_text) + except json.JSONDecodeError as e: + pytest.fail( + f"Invalid JSON 
from backend={backend}: {generated_text!r}\n" + f"Schema: {json_schema}\nError: {e}") jsonschema.validate(instance=output_json, schema=json_schema) # @@ -468,7 +468,12 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - output_json = json.loads(generated_text) + try: + output_json = json.loads(generated_text) + except json.JSONDecodeError as e: + pytest.fail( + f"Invalid JSON from backend={backend}: {generated_text!r}\n" + f"Schema: {json_schema}\nError: {e}") jsonschema.validate(instance=output_json, schema=json_schema) if backend not in ["outlines", "lm-format-enforcer"]: diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py index 991070dc9239d..f39a8021a29ef 100644 --- a/tests/v1/tpu/test_tpu_int8.py +++ b/tests/v1/tpu/test_tpu_int8.py @@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int, prompts = [ "A robot may not injure a human being", - "It is only with the heart that one can see rightly;", - "The greatest glory in living lies not in never falling,", ] answers = [ - "or, being injured, not kill, except in", - "without the heart, one can only see wrongly.", - "but in rising every time we fall. 
- Nelson" + "or kill a human being", ] with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm: diff --git a/tests/v1/worker/test_worker_memory_snapshot.py b/tests/v1/worker/test_worker_memory_snapshot.py new file mode 100644 index 0000000000000..6faa6bcc591cb --- /dev/null +++ b/tests/v1/worker/test_worker_memory_snapshot.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing as mp +import os +import tempfile +from multiprocessing import Queue +from typing import Optional +from unittest.mock import patch + +import pytest +import torch + +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import MemorySnapshot +from vllm.v1.worker.gpu_worker import (Worker, + init_worker_distributed_environment) + +# Global queue to track operation order across processes +_QUEUE: Optional[Queue] = None + + +def track_operation(operation: str, rank: int): + """Track when an operation happens and its rank.""" + if _QUEUE is not None: + _QUEUE.put((operation, rank)) + + +def make_operation_tracker(operation_name: str, original_func): + """Create a mock function that tracks when an operation is called. 
+ + Args: + operation_name: Name to use when tracking this operation + original_func: The original function to wrap + + Returns: + A wrapper function that tracks the operation and calls the original + """ + + def wrapper(*args, **kwargs): + rank = int(os.environ.get("RANK", "-1")) + track_operation(operation_name, rank) + return original_func(*args, **kwargs) + + return wrapper + + +def worker_process(rank: int, world_size: int, distributed_init_method: str, + queue: Queue, error_queue: Queue): + """Worker process that initializes a GPU worker with proper tracking.""" + global _QUEUE + _QUEUE = queue + + try: + # Set environment variables + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + # Create vLLM config with small model + vllm_config = EngineArgs(model="facebook/opt-125m", + tensor_parallel_size=2, + load_format="dummy").create_engine_config() + + # Create worker + worker = Worker( + vllm_config=vllm_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + # Get original functions before patching + original_init_worker = init_worker_distributed_environment + original_memory_snapshot_init = MemorySnapshot.__init__ + original_all_reduce = torch.distributed.all_reduce + + # Apply minimal patches to track operation order + init_patch = patch( + 'vllm.v1.worker.gpu_worker.init_worker_distributed_environment', + side_effect=make_operation_tracker("init_distributed", + original_init_worker)) + memory_patch = patch.object( + MemorySnapshot, '__init__', + make_operation_tracker("memory_snapshot", + original_memory_snapshot_init)) + all_reduce_patch = patch('torch.distributed.all_reduce', + side_effect=make_operation_tracker( + "nccl_all_reduce", original_all_reduce)) + + with init_patch, memory_patch, all_reduce_patch: + + # Initialize device (this is where we test the order) + worker.init_device() + + # Load model to ensure everything works + 
worker.load_model() + + # Signal success + queue.put(("success", rank)) + + except Exception as e: + error_queue.put((rank, str(e), type(e).__name__)) + raise + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs for tensor parallelism") +def test_init_distributed_is_called_before_memory_snapshot(): + """Test that distributed env is setup before memory snapshot. + + This test makes sure during worker initialization, the initial memory + snapshot is taken after distributed env is setup to include all the buffers + allocated by distributed env. + """ + world_size = 2 + + # Create a temporary file for distributed init + with tempfile.NamedTemporaryFile(delete=False) as f: + distributed_init_method = f"file://{f.name}" + + # Create queues for inter-process communication + ctx = mp.get_context("spawn") + operation_queue = ctx.Queue() + error_queue = ctx.Queue() + + # Start worker processes + processes = [] + for rank in range(world_size): + p = ctx.Process(target=worker_process, + args=(rank, world_size, distributed_init_method, + operation_queue, error_queue)) + p.start() + processes.append(p) + + # Wait for all processes to complete + for p in processes: + p.join(timeout=60) # 60 second timeout + + # Check for errors + errors = [] + while not error_queue.empty(): + rank, error_msg, error_type = error_queue.get() + errors.append(f"Rank {rank}: {error_type}: {error_msg}") + + if errors: + pytest.fail("Worker processes failed:\n" + "\n".join(errors)) + + # Collect all operations from the queue + operations = [] + while not operation_queue.empty(): + operations.append(operation_queue.get()) + + # Verify we got operations from both ranks + print(f"Collected operations: {operations}") + + # Check operations for each rank + for rank in range(world_size): + rank_ops = [op for op, r in operations if r == rank] + print(f"\nRank {rank} operations: {rank_ops}") + + # Raises ValueError if the operation is not found + init_distributed = 
rank_ops.index("init_distributed") + nccl_all_reduce = rank_ops.index("nccl_all_reduce") + memory_snapshot = rank_ops.index("memory_snapshot") + + # Verify order: init_distributed should happen before memory_snapshot + assert init_distributed < nccl_all_reduce < memory_snapshot, ( + f"Rank {rank}: init_distributed (index {init_distributed}) " + f"must happen before nccl_all_reduce (index {nccl_all_reduce}) " + f"and memory_snapshot (index {memory_snapshot})") + + # Clean up + os.unlink(distributed_init_method.replace("file://", "")) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py deleted file mode 100644 index cddeb2cf39bf0..0000000000000 --- a/vllm/attention/backends/placeholder_attn.py +++ /dev/null @@ -1,314 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from itertools import accumulate -from typing import List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataBuilder) -from vllm.attention.backends.utils import CommonAttentionState -from vllm.utils import async_tensor_h2d - -# Placeholder attention backend for models like Mamba and pooling models that -# lack attention. 
- - -class PlaceholderAttentionBackend(AttentionBackend): - """Placeholder backend for when no attention is needed.""" - - @staticmethod - def get_name() -> str: - return "NO_ATTENTION" - - @staticmethod - def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: - return PlaceholderAttentionImpl - - @staticmethod - def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: - return PlaceholderAttentionMetadataBuilder - - @staticmethod - def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: - return PlaceholderAttentionMetadata - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (1, 1, 1, 1, 1) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - return - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - return - - -@dataclass -class PlaceholderAttentionMetadata(AttentionMetadata): - """Attention metadata for prefill and decode batched together.""" - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. 
- # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - # Placeholder. - block_tables: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None - _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - - # Placeholders - slot_mapping = torch.empty(0) - block_tables = torch.empty(0) - - self._cached_prefill_metadata = PlaceholderAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - 
num_decode_tokens=0, - slot_mapping=slot_mapping, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=0, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Placeholders - slot_mapping = torch.empty(0) - block_tables = torch.empty(0) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - - self._cached_decode_metadata = PlaceholderAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - ) - return self._cached_decode_metadata - - -class PlaceholderAttentionMetadataBuilder( - AttentionMetadataBuilder[PlaceholderAttentionMetadata]): - - def __init__(self, input_builder): - - self.input_builder = input_builder - self.runner = input_builder.runner - - def 
prepare(self): - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.curr_seq_lens: List[int] = [] - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - """ - is_prompt = inter_data.is_prompt - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - - # Some input builders such as ModelInputForCPUBuilder do not have the - # "inter_data_list" attribute. - # Let's check inter_data_list exists before we reference it. 
- if hasattr(self.input_builder, "inter_data_list"): - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - if use_captured_graph: - num_decode_tokens = batch_size - self.num_prefill_tokens - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - - # Placeholders - slot_mapping_tensor = torch.empty(0) - block_tables = torch.empty(0) - - return PlaceholderAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - enable_kv_scales_calculation=True, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - 
context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class PlaceholderAttentionImpl(AttentionImpl): - - def __init__(self, *args, **kwargs) -> None: - return - - def forward(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 63ee8f50825c5..accb3ab6ae2b0 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -304,7 +304,7 @@ class CommonAttentionState(AttentionState): max_query_len=1, max_decode_query_len=1, max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, + max_decode_seq_len=self.runner.max_model_len, query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, @@ -390,7 +390,7 @@ class CommonAttentionState(AttentionState): dtype=torch.int).cuda() attn_metadata.encoder_seq_lens_tensor = torch.full( (batch_size, ), 1, dtype=torch.int).cuda() - attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.max_encoder_seq_len = self.runner.max_model_len attn_metadata.num_encoder_tokens = 0 def _add_additional_input_buffers_for_enc_dec_model( diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 544a720524429..baa83e29bdd05 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -115,12 +115,10 @@ class Attention(nn.Module, AttentionLayerBase): if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size - is_attention_free = cache_config.is_attention_free calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 - is_attention_free = False calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads @@ -185,7 +183,6 @@ class Attention(nn.Module, AttentionLayerBase): dtype, kv_cache_dtype, block_size, - is_attention_free, use_mla=use_mla, has_sink=self.has_sink) else: @@ 
-578,9 +575,7 @@ def unified_attention_fake( direct_register_custom_op( op_name="unified_attention", op_func=unified_attention, - mutates_args=[], fake_impl=unified_attention_fake, - dispatch_key=current_platform.dispatch_key, tags=tag_cudagraph_unsafe, ) @@ -631,6 +626,5 @@ direct_register_custom_op( op_func=unified_attention_with_output, mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, - dispatch_key=current_platform.dispatch_key, tags=tag_cudagraph_unsafe, ) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index 0b0c706626af3..883052cb46aab 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -2,10 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -import triton -import triton.language as tl from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton @triton.jit diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 3a235ba6e0b42..b651fc3eaee36 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -142,7 +142,6 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool = False, use_mla: bool = False, has_sink: bool = False, ) -> type[AttentionBackend]: @@ -156,7 +155,6 @@ def get_attn_backend( dtype=dtype, kv_cache_dtype=kv_cache_dtype, block_size=block_size, - is_attention_free=is_attention_free, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, has_sink=has_sink, @@ -169,17 +167,10 @@ def _cached_get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool, use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, ) -> type[AttentionBackend]: - # If there are no attention layers (e.g. 
we are running Mamba), - # use the placeholder NO_ATTENTION - if is_attention_free: - from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionBackend) - return PlaceholderAttentionBackend # Check whether a particular choice of backend was # previously forced. diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 331cd8a873929..04b76a9c2d228 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -547,7 +547,6 @@ if flashinfer_comm is not None: "scale_out", ], fake_impl=call_trtllm_fused_allreduce_norm_fake, - dispatch_key=current_platform.dispatch_key, ) flashinfer_trtllm_fused_allreduce_norm = ( torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 7158fd685964f..eeca14d1296f3 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -551,8 +551,9 @@ def set_inductor_config(config, runtime_shape): if isinstance(runtime_shape, int): # for a specific batchsize, tuning triton kernel parameters # can be beneficial - config["max_autotune"] = True - config["coordinate_descent_tuning"] = True + config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE + config["coordinate_descent_tuning"] = ( + envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING) class EagerAdaptor(CompilerInterface): diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index b7a6e23c1aa79..6e9a36a2b0b99 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -8,6 +8,7 @@ from unittest.mock import patch import torch import torch.nn as nn +from packaging import version from torch._dynamo.symbolic_convert import InliningInstructionTranslator from vllm.compilation.counter import compilation_counter @@ -300,13 +301,13 @@ def _support_torch_compile( logger.debug( "enable_cpp_symbolic_shape_guards 
config not available") - with patch.object(InliningInstructionTranslator, 'inline_call', - patched_inline_call), torch._dynamo.config.patch( - **dynamo_config_patches - ), maybe_use_cudagraph_partition_wrapper( - self.vllm_config): + with patch.object( + InliningInstructionTranslator, "inline_call", + patched_inline_call), torch._dynamo.config.patch( + **dynamo_config_patches + ), maybe_use_cudagraph_partition_wrapper( + self.vllm_config), _torch27_patch_tensor_subclasses(): output = self.compiled_callable(*args, **kwargs) - return output # usually, capturing the model once is enough, and then we can @@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE and compilation_config.use_inductor_graph_partition): torch._inductor.utils.set_customized_partition_wrappers(None) + + +@contextlib.contextmanager +def _torch27_patch_tensor_subclasses(): + """ + Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.) when + using torch 2.7.0. This enables using weight_loader_v2 and the use of + `BasevLLMParameters` without having to replace them with regular tensors + before `torch.compile`-time. 
+ """ + from vllm.model_executor.parameter import (BasevLLMParameter, + ModelWeightParameter, + RowvLLMParameter, + _ColumnvLLMParameter) + + def return_false(*args, **kwargs): + return False + + if version.parse("2.7") <= version.parse( + torch.__version__) < version.parse("2.8"): + yield + return + + with (torch._dynamo.config.patch("traceable_tensor_subclasses", [ + BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter, + RowvLLMParameter + ]), + patch("torch._dynamo.variables.torch.can_dispatch_torch_function", + return_false)): + yield diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index d786d3e289b33..df6564077e8aa 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -27,6 +27,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) +from vllm.config.device import Device, DeviceConfig from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_transfer import KVTransferConfig from vllm.config.load import LoadConfig @@ -38,11 +39,13 @@ from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode, try_match_architecture_defaults) from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode, MultiModalConfig) +from vllm.config.observability import DetailedTraceModules, ObservabilityConfig from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy from vllm.config.speculative import SpeculativeConfig +from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.structured_outputs import StructuredOutputsConfig from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field from vllm.logger import init_logger @@ -81,158 +84,6 @@ class SupportsMetricsInfo(Protocol): ... 
-Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class DeviceConfig: - """Configuration for the device to use for vLLM execution.""" - - device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto" - """Device type for vLLM execution. - This parameter is deprecated and will be - removed in a future release. - It will now be set automatically based - on the current platform.""" - device_type: str = field(init=False) - """Device type from the current platform. This is set in - `__post_init__`.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # the device/platform information will be summarized - # by torch/vllm automatically. 
- factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if self.device == "auto": - # Automated device type detection - from vllm.platforms import current_platform - self.device_type = current_platform.device_type - if not self.device_type: - raise RuntimeError( - "Failed to infer device type, please set " - "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " - "to turn on verbose logging to help debug the issue.") - else: - # Device type is assigned explicitly - if isinstance(self.device, str): - self.device_type = self.device - elif isinstance(self.device, torch.device): - self.device_type = self.device.type - - # Some device types require processing inputs on CPU - if self.device_type in ["tpu"]: - self.device = None - else: - # Set device with device type - self.device = torch.device(self.device_type) - - -DetailedTraceModules = Literal["model", "worker", "all"] - - -@config -@dataclass -class ObservabilityConfig: - """Configuration for observability - metrics and tracing.""" - - show_hidden_metrics_for_version: Optional[str] = None - """Enable deprecated Prometheus metrics that have been hidden since the - specified version. For example, if a previously deprecated metric has been - hidden since the v0.7.0 release, you use - `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while - you migrate to new metrics. 
The metric is likely to be removed completely - in an upcoming release.""" - - @cached_property - def show_hidden_metrics(self) -> bool: - """Check if the hidden metrics should be shown.""" - if self.show_hidden_metrics_for_version is None: - return False - return version._prev_minor_version_was( - self.show_hidden_metrics_for_version) - - otlp_traces_endpoint: Optional[str] = None - """Target URL to which OpenTelemetry traces will be sent.""" - - collect_detailed_traces: Optional[list[DetailedTraceModules]] = None - """It makes sense to set this only if `--otlp-traces-endpoint` is set. If - set, it will collect detailed traces for the specified modules. This - involves use of possibly costly and or blocking operations and hence might - have a performance impact. - - Note that collecting detailed timing information for each request can be - expensive.""" - - @cached_property - def collect_model_forward_time(self) -> bool: - """Whether to collect model forward time for the request.""" - return (self.collect_detailed_traces is not None - and ("model" in self.collect_detailed_traces - or "all" in self.collect_detailed_traces)) - - @cached_property - def collect_model_execute_time(self) -> bool: - """Whether to collect model execute time for the request.""" - return (self.collect_detailed_traces is not None - and ("worker" in self.collect_detailed_traces - or "all" in self.collect_detailed_traces)) - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. 
- factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if (self.collect_detailed_traces is not None - and len(self.collect_detailed_traces) == 1 - and "," in self.collect_detailed_traces[0]): - self._parse_collect_detailed_traces() - - from vllm.tracing import is_otel_available, otel_import_error_traceback - if not is_otel_available() and self.otlp_traces_endpoint is not None: - raise ValueError( - "OpenTelemetry is not available. Unable to configure " - "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " - f"installed. Original error:\n{otel_import_error_traceback}") - - def _parse_collect_detailed_traces(self): - assert isinstance(self.collect_detailed_traces, list) - self.collect_detailed_traces = cast( - list[DetailedTraceModules], - self.collect_detailed_traces[0].split(",")) - - @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: @@ -1009,37 +860,6 @@ def get_layers_from_vllm_config( } -@config -@dataclass -class SpeechToTextConfig: - """Configuration for speech-to-text models.""" - - sample_rate: float = 16_000 - """Sample rate (Hz) to resample input audio to. Most speech models expect - 16kHz audio input. The input audio will be automatically resampled to this - rate before processing.""" - - max_audio_clip_s: int = 30 - """Maximum duration in seconds for a single audio clip without chunking. - Audio longer than this will be split into smaller chunks if - `allow_audio_chunking` evaluates to True, otherwise it will be rejected.""" - - overlap_chunk_second: int = 1 - """Overlap duration in seconds between consecutive audio chunks when - splitting long audio. This helps maintain context across chunk boundaries - and improves transcription quality at split points.""" - - min_energy_split_window_size: Optional[int] = 1600 - """Window size in samples for finding low-energy (quiet) regions to split - audio chunks. 
The algorithm looks for the quietest moment within this - window to minimize cutting through speech. Default 1600 samples ≈ 100ms - at 16kHz. If None, no chunking will be done.""" - - @property - def allow_audio_chunking(self) -> bool: - return self.min_energy_split_window_size is not None - - def update_config(config: DataclassInstanceT, overrides: dict[str, Any]) -> DataclassInstanceT: processed_overrides = {} diff --git a/vllm/config/device.py b/vllm/config/device.py new file mode 100644 index 0000000000000..4654ac96e0b7b --- /dev/null +++ b/vllm/config/device.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import Any, Literal, Optional, Union + +import torch +from pydantic import ConfigDict, SkipValidation +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class DeviceConfig: + """Configuration for the device to use for vLLM execution.""" + + device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto" + """Device type for vLLM execution. + This parameter is deprecated and will be + removed in a future release. + It will now be set automatically based + on the current platform.""" + device_type: str = field(init=False) + """Device type from the current platform. This is set in + `__post_init__`.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + # no factors to consider. + # the device/platform information will be summarized + # by torch/vllm automatically. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if self.device == "auto": + # Automated device type detection + from vllm.platforms import current_platform + self.device_type = current_platform.device_type + if not self.device_type: + raise RuntimeError( + "Failed to infer device type, please set " + "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " + "to turn on verbose logging to help debug the issue.") + else: + # Device type is assigned explicitly + if isinstance(self.device, str): + self.device_type = self.device + elif isinstance(self.device, torch.device): + self.device_type = self.device.type + + # Some device types require processing inputs on CPU + if self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) diff --git a/vllm/config/model.py b/vllm/config/model.py index d8a8fe20fd030..f37489bdfff59 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -177,11 +177,6 @@ class ModelConfig: graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid for maximal performance and flexibility.""" - max_seq_len_to_capture: int = 8192 - """Maximum sequence len covered by CUDA graphs. When a sequence has context - length larger than this, we fall back to eager mode. Additionally for - encoder-decoder models, if the sequence length of the encoder input is - larger than this, we fall back to the eager mode.""" max_logprobs: int = 20 """Maximum number of log probabilities to return when `logprobs` is specified in `SamplingParams`. 
The default value comes the default for the @@ -699,11 +694,12 @@ class ModelConfig: model: Model name or path tokenizer: Tokenizer name or path """ + if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)): return if is_runai_obj_uri(model): - object_storage_model = ObjectStorageModel() + object_storage_model = ObjectStorageModel(url=model) object_storage_model.pull_files( model, allow_pattern=["*.model", "*.py", "*.json"]) self.model_weights = model @@ -722,7 +718,7 @@ class ModelConfig: # Only download tokenizer if needed and not already handled if is_runai_obj_uri(tokenizer): - object_storage_tokenizer = ObjectStorageModel() + object_storage_tokenizer = ObjectStorageModel(url=tokenizer) object_storage_tokenizer.pull_files(model, ignore_pattern=[ "*.pt", "*.safetensors", @@ -1023,21 +1019,8 @@ class ModelConfig: current_platform.verify_quantization(self.quantization) def _verify_cuda_graph(self) -> None: - # The `max_seq_len_to_capture` was incorrectly - # based on the encoder's input length (448) - # but not the decoder's larger input length (1500). - # This change ensures the CUDA Graph captures the correct, - # larger sequence length, allowing it to work as intended. 
- effective_max_seq_len = self.max_model_len - if self.is_encoder_decoder: - effective_max_seq_len = max( - effective_max_seq_len, - getattr(self.hf_config, "max_source_positions", 0)) - self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, - effective_max_seq_len) # CUDAGraph capture not supported for encoder-decoder models on ROCm unsupported_rocm = self.is_encoder_decoder - if (unsupported_rocm and not self.enforce_eager and current_platform.is_rocm()): logger.warning( diff --git a/vllm/config/observability.py b/vllm/config/observability.py new file mode 100644 index 0000000000000..766d03051e212 --- /dev/null +++ b/vllm/config/observability.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from functools import cached_property +from typing import Any, Literal, Optional, cast + +from pydantic.dataclasses import dataclass + +from vllm import version +from vllm.config.utils import config + +DetailedTraceModules = Literal["model", "worker", "all"] + + +@config +@dataclass +class ObservabilityConfig: + """Configuration for observability - metrics and tracing.""" + + show_hidden_metrics_for_version: Optional[str] = None + """Enable deprecated Prometheus metrics that have been hidden since the + specified version. For example, if a previously deprecated metric has been + hidden since the v0.7.0 release, you use + `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while + you migrate to new metrics. 
The metric is likely to be removed completely + in an upcoming release.""" + + @cached_property + def show_hidden_metrics(self) -> bool: + """Check if the hidden metrics should be shown.""" + if self.show_hidden_metrics_for_version is None: + return False + return version._prev_minor_version_was( + self.show_hidden_metrics_for_version) + + otlp_traces_endpoint: Optional[str] = None + """Target URL to which OpenTelemetry traces will be sent.""" + + collect_detailed_traces: Optional[list[DetailedTraceModules]] = None + """It makes sense to set this only if `--otlp-traces-endpoint` is set. If + set, it will collect detailed traces for the specified modules. This + involves use of possibly costly and or blocking operations and hence might + have a performance impact. + + Note that collecting detailed timing information for each request can be + expensive.""" + + @cached_property + def collect_model_forward_time(self) -> bool: + """Whether to collect model forward time for the request.""" + return (self.collect_detailed_traces is not None + and ("model" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces)) + + @cached_property + def collect_model_execute_time(self) -> bool: + """Whether to collect model execute time for the request.""" + return (self.collect_detailed_traces is not None + and ("worker" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces)) + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if (self.collect_detailed_traces is not None + and len(self.collect_detailed_traces) == 1 + and "," in self.collect_detailed_traces[0]): + self._parse_collect_detailed_traces() + + from vllm.tracing import is_otel_available, otel_import_error_traceback + if not is_otel_available() and self.otlp_traces_endpoint is not None: + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}") + + def _parse_collect_detailed_traces(self): + assert isinstance(self.collect_detailed_traces, list) + self.collect_detailed_traces = cast( + list[DetailedTraceModules], + self.collect_detailed_traces[0].split(",")) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index d533930e1c7aa..34b17628def1f 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -285,8 +285,6 @@ class SpeculativeConfig: max_model_len, quantization=self.quantization, enforce_eager=self.target_model_config.enforce_eager, - max_seq_len_to_capture=self.target_model_config. 
- max_seq_len_to_capture, max_logprobs=self.target_model_config.max_logprobs, hf_overrides=SpeculativeConfig.hf_config_override, ) diff --git a/vllm/config/speech_to_text.py b/vllm/config/speech_to_text.py new file mode 100644 index 0000000000000..de9f525efe185 --- /dev/null +++ b/vllm/config/speech_to_text.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class SpeechToTextConfig: + """Configuration for speech-to-text models.""" + + sample_rate: float = 16_000 + """Sample rate (Hz) to resample input audio to. Most speech models expect + 16kHz audio input. The input audio will be automatically resampled to this + rate before processing.""" + + max_audio_clip_s: int = 30 + """Maximum duration in seconds for a single audio clip without chunking. + Audio longer than this will be split into smaller chunks if + `allow_audio_chunking` evaluates to True, otherwise it will be rejected.""" + + overlap_chunk_second: int = 1 + """Overlap duration in seconds between consecutive audio chunks when + splitting long audio. This helps maintain context across chunk boundaries + and improves transcription quality at split points.""" + + min_energy_split_window_size: Optional[int] = 1600 + """Window size in samples for finding low-energy (quiet) regions to split + audio chunks. The algorithm looks for the quietest moment within this + window to minimize cutting through speech. Default 1600 samples ≈ 100ms + at 16kHz. 
If None, no chunking will be done.""" + + @property + def allow_audio_chunking(self) -> bool: + return self.min_energy_split_window_size is not None diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 75de85e1b0aba..76fe9a93259fa 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm): direct_register_custom_op( op_name="all_reduce_symmetric_with_copy", op_func=all_reduce_symmetric_with_copy_impl, - mutates_args=[], fake_impl=all_reduce_symmetric_with_copy_fake, ) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index c7810043b81e8..deeed1f21b4e1 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -392,7 +392,8 @@ class MessageQueue: > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): logger.debug( ("No available shared memory broadcast block found" - " in %s second."), + " in %s seconds. This typically happens when some" + " processes are hanging."), VLLM_RINGBUFFER_WARNING_INTERVAL, ) n_warning += 1 @@ -455,7 +456,8 @@ class MessageQueue: > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): logger.debug( ("No available shared memory broadcast block found" - " in %s second."), + " in %s seconds. 
This typically happens when some" + " processes are hanging."), VLLM_RINGBUFFER_WARNING_INTERVAL, ) n_warning += 1 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 64feddb591c27..528d4022bd17a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -574,7 +574,6 @@ class NixlConnectorWorker: self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, - self.model_config.is_attention_free, use_mla=self.use_mla) self.backend_name = backend.get_name() attn_backend = backend_name_to_enum(self.backend_name) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 895971893a661..69f98eb54f36c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, if supports_custom_op(): - from vllm.platforms import current_platform direct_register_custom_op( op_name="all_reduce", op_func=all_reduce, - mutates_args=[], fake_impl=all_reduce_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="reduce_scatter", op_func=reduce_scatter, - mutates_args=[], fake_impl=reduce_scatter_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="all_gather", op_func=all_gather, - mutates_args=[], fake_impl=all_gather_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 556a490ffa109..3f0dfce1b4b50 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -373,7 +373,6 @@ class EngineArgs: tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision quantization: Optional[QuantizationMethods] = ModelConfig.quantization enforce_eager: bool = ModelConfig.enforce_eager - 
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce limit_mm_per_prompt: dict[str, int] = \ get_field(MultiModalConfig, "limit_per_prompt") @@ -545,8 +544,6 @@ class EngineArgs: **model_kwargs["quantization"]) model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"]) - model_group.add_argument("--max-seq-len-to-capture", - **model_kwargs["max_seq_len_to_capture"]) model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) model_group.add_argument("--logprobs-mode", @@ -1008,7 +1005,6 @@ class EngineArgs: max_model_len=self.max_model_len, quantization=self.quantization, enforce_eager=self.enforce_eager, - max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, logprobs_mode=self.logprobs_mode, disable_sliding_window=self.disable_sliding_window, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c41f44aa47187..dfe535b959179 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -130,11 +130,6 @@ class LLM: enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. - max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode. Additionally for encoder-decoder models, if the - sequence length of the encoder input is larger than this, we fall - back to the eager mode. disable_custom_all_reduce: See [ParallelConfig][vllm.config.ParallelConfig]. 
hf_token: The token to use as HTTP bearer authorization for remote files @@ -184,7 +179,6 @@ class LLM: swap_space: float = 4, cpu_offload_gb: float = 0, enforce_eager: bool = False, - max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, @@ -281,7 +275,6 @@ class LLM: swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, - max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, hf_token=hf_token, hf_overrides=hf_overrides, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 16564214e353a..0780448ad7332 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1186,6 +1186,10 @@ class OpenAIServingChat(OpenAIServing): logprobs = None if self.use_harmony: + reasoning_content, content, _ = parse_chat_output(token_ids) + if not request.include_reasoning: + reasoning_content = None + if self.tool_parser is not None: tool_parser = self.tool_parser(tokenizer) # NOTE: We use token_ids for openai tool parser @@ -1194,10 +1198,7 @@ class OpenAIServingChat(OpenAIServing): request=request, token_ids=token_ids, # type: ignore ) - reasoning_content, content = None, tool_call_info.content - if request.include_reasoning: - reasoning_content, content, _ = parse_chat_output( - token_ids) + content = tool_call_info.content message = ChatMessage( role=role, reasoning_content=reasoning_content, @@ -1205,10 +1206,6 @@ class OpenAIServingChat(OpenAIServing): tool_calls=tool_call_info.tool_calls, ) else: - reasoning_content, content, _ = parse_chat_output( - token_ids) - if not request.include_reasoning: - reasoning_content = None message = ChatMessage( role=role, reasoning_content=reasoning_content, diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 
99bb464db1d13..c70baba88d433 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -235,8 +235,6 @@ class OpenAIServingResponses(OpenAIServing): # Handle the previous response ID. prev_response_id = request.previous_response_id if prev_response_id is not None: - if not prev_response_id.startswith("resp_"): - return self._make_invalid_id_error(prev_response_id) async with self.response_store_lock: prev_response = self.response_store.get(prev_response_id) if prev_response is None: @@ -924,9 +922,6 @@ class OpenAIServingResponses(OpenAIServing): stream: Optional[bool], ) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[ StreamingResponsesResponse, None]]: - if not response_id.startswith("resp_"): - return self._make_invalid_id_error(response_id) - async with self.response_store_lock: response = self.response_store.get(response_id) @@ -944,9 +939,6 @@ class OpenAIServingResponses(OpenAIServing): self, response_id: str, ) -> Union[ErrorResponse, ResponsesResponse]: - if not response_id.startswith("resp_"): - return self._make_invalid_id_error(response_id) - async with self.response_store_lock: response = self.response_store.get(response_id) if response is None: @@ -972,13 +964,6 @@ class OpenAIServingResponses(OpenAIServing): response_id) return response - def _make_invalid_id_error(self, response_id: str) -> ErrorResponse: - return self.create_error_response( - err_type="invalid_request_error", - message=(f"Invalid 'response_id': '{response_id}'. 
" - "Expected an ID that begins with 'resp'."), - ) - def _make_not_found_error(self, response_id: str) -> ErrorResponse: return self.create_error_response( err_type="invalid_request_error", diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py index ff9188190f3f0..09095f8991773 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -39,7 +39,7 @@ class DeepSeekV31ToolParser(ToolParser): self.tool_call_end_token: str = "<|tool▁call▁end|>" self.tool_call_regex = re.compile( - r"<|tool▁call▁begin|>(?P.*)<|tool▁sep|>(?P.*)<|tool▁call▁end|>" + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)<|tool▁call▁end|>" ) self.stream_tool_call_portion_regex = re.compile( diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py index c5d59514b9445..1729fdbc99710 100644 --- a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations +import json from collections.abc import Sequence from typing import TYPE_CHECKING @@ -12,10 +13,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) +from vllm.logger import init_logger if TYPE_CHECKING: from vllm.transformers_utils.tokenizer import AnyTokenizer +logger = init_logger(__name__) + @ToolParserManager.register_module("openai") class OpenAIToolParser(ToolParser): @@ -40,17 +44,33 @@ class OpenAIToolParser(ToolParser): if len(parser.messages) > 0: for msg in parser.messages: + if len(msg.content) < 1: + continue + msg_text = msg.content[0].text if 
msg.recipient and msg.recipient.startswith("functions."): + # If no content-type is given assume JSON, as that's the + # most common case with gpt-oss models. + if not msg.content_type or "json" in msg.content_type: + # load and dump the JSON text to check validity and + # remove any extra newlines or other odd formatting + try: + tool_args = json.dumps(json.loads(msg_text)) + except json.JSONDecodeError: + logger.exception( + "Error decoding JSON tool call from response.") + tool_args = msg_text + else: + tool_args = msg_text tool_calls.append( ToolCall( type="function", function=FunctionCall( name=msg.recipient.split("functions.")[1], - arguments=msg.content[0].text, + arguments=tool_args, ), )) elif msg.channel == "final": - final_content = msg.content[0].text + final_content = msg_text return ExtractedToolCallInformation( tools_called=len(tool_calls) > 0, diff --git a/vllm/envs.py b/vllm/envs.py index 50d58c5468f97..689428ec59109 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -119,7 +119,7 @@ if TYPE_CHECKING: VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False - VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16 + VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 32 VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None @@ -187,12 +187,15 @@ if TYPE_CHECKING: VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False + VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024 VLLM_DBO_COMM_SMS: int = 20 GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None + VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True + VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True 
VLLM_USE_NCCL_SYMM_MEM: bool = False VLLM_NCCL_INCLUDE_PATH: Optional[str] = None @@ -1014,7 +1017,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # max number splits for cuda graph decode "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", - "16")), + "32")), # Number of GPUs per worker in Ray, if it is set to be a fraction, # it allows ray to schedule multiple actors on a single GPU, @@ -1385,6 +1388,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), + # Add optional nvtx scopes for profiling, disable to avoid overheads + "VLLM_NVTX_SCOPES_FOR_PROFILING": + lambda: bool(int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0"))), + # Represent block hashes in KV cache events as 64-bit integers instead of # raw bytes. Defaults to True for backward compatibility. "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": @@ -1413,6 +1420,17 @@ environment_variables: dict[str, Callable[[], Any]] = { "code_interpreter", "web_search_preview"]), + # Enable max_autotune & coordinate_descent_tuning in inductor_config + # to compile static shapes passed from compile_sizes in compilation_config + # If set to 1, enable max_autotune; By default, this is enabled (1) + "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE": + lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1"))), + # If set to 1, enable coordinate_descent_tuning; + # By default, this is enabled (1) + "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING": + lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", + "1"))), + # Flag to enable NCCL symmetric memory allocation and registration "VLLM_USE_NCCL_SYMM_MEM": lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))), @@ -1513,6 +1531,8 @@ def compute_hash() -> str: "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN", + "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", + "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", ] for key in environment_variables_to_hash: # if this goes out of sync with environment_variables, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 46f49aaa013da..3f1cac531f450 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -7,7 +7,6 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import InputContext, InputProcessingContext __all__ = [ "DataPrompt", @@ -28,6 +27,4 @@ __all__ = [ "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", - "InputContext", - "InputProcessingContext", ] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py deleted file mode 100644 index b5316b6d0574c..0000000000000 --- a/vllm/inputs/registry.py +++ /dev/null @@ -1,186 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Mapping -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Union - -import torch -from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from typing_extensions import TypeVar - -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_processor_from_config -from vllm.utils import get_allowed_kwarg_only_overrides -from vllm.utils.jsontree import JSONTree, json_map_leaves - -if TYPE_CHECKING: - from vllm.config import ModelConfig - from vllm.transformers_utils.tokenizer import AnyTokenizer -else: - ModelConfig = Any - AnyTokenizer = Any - -_T = TypeVar("_T") -_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) -_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) - -logger = 
init_logger(__name__) - - -@dataclass(frozen=True) -class InputContext: - """ - Contains information about the model which may be used to - modify the inputs. - """ - - model_config: ModelConfig - """The configuration of the model.""" - - def get_hf_config( - self, - typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig, - /, - ) -> _C: - """ - Get the HuggingFace configuration - (`transformers.PretrainedConfig`) of the model, - additionally checking its type. - - Raises: - TypeError: If the configuration is not of the specified type. - """ - hf_config = self.model_config.hf_config - if not isinstance(hf_config, typ): - raise TypeError("Invalid type of HuggingFace config. " - f"Expected type: {typ}, but " - f"found type: {type(hf_config)}") - - return hf_config - - def get_hf_image_processor_config(self) -> dict[str, Any]: - """ - Get the HuggingFace image processor configuration of the model. - """ - return self.model_config.hf_image_processor_config - - def get_mm_config(self): - """ - Get the multimodal config of the model. - - Raises: - RuntimeError: If the model is not a multimodal model. - """ - mm_config = self.model_config.multimodal_config - if mm_config is None: - raise RuntimeError("Not a multimodal model") - - return mm_config - - def get_hf_processor( - self, - typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, - /, - **kwargs: object, - ) -> _P: - """ - Get the HuggingFace processor - (`transformers.ProcessorMixin`) of the model, - additionally checking its type. - - Raises: - TypeError: If the processor is not of the specified type. - """ - return cached_processor_from_config( - self.model_config, - processor_cls=typ, - **kwargs, - ) - - def init_processor( - self, - typ: type[_T], - /, - **kwargs: object, - ) -> _T: - """ - Initialize a HuggingFace-like processor class, merging the - keyword arguments with those in the model's configuration. 
- """ - mm_config = self.model_config.get_multimodal_config() - base_kwargs = mm_config.mm_processor_kwargs - if base_kwargs is None: - base_kwargs = {} - - merged_kwargs = {**base_kwargs, **kwargs} - - return typ(**merged_kwargs) - - -@dataclass(frozen=True) -class InputProcessingContext(InputContext): - tokenizer: AnyTokenizer - """The tokenizer used to tokenize the inputs.""" - - def get_hf_processor( - self, - typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, - /, - **kwargs: object, - ) -> _P: - return super().get_hf_processor( - typ, - tokenizer=self.tokenizer, - **kwargs, - ) - - def call_hf_processor( - self, - hf_processor: ProcessorMixin, - data: Mapping[str, object], - kwargs: Mapping[str, object] = {}, - ) -> Union[BatchFeature, JSONTree]: - """ - Call `hf_processor` on the prompt `data` - (text, image, audio...) with configurable options `kwargs`. - """ - assert callable(hf_processor) - - mm_config = self.model_config.get_multimodal_config() - merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) - - allowed_kwargs = get_allowed_kwarg_only_overrides( - hf_processor, - merged_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - - def maybe_cast_dtype(x): - # This mimics the behavior of transformers.BatchFeature - if isinstance(x, torch.Tensor) and x.is_floating_point(): - return x.to(dtype=self.model_config.dtype) - return x - - try: - output = hf_processor(**data, - **allowed_kwargs, - return_tensors="pt") - # this emulates output.to(dtype=self.model_config.dtype) - if isinstance(output, BatchFeature): - cast_output = json_map_leaves(maybe_cast_dtype, output.data) - return BatchFeature(cast_output) - - cast_output = json_map_leaves(maybe_cast_dtype, output) - - logger.warning_once( - f"{type(hf_processor).__name__} did not return `BatchFeature`. 
" - "Make sure to match the behaviour of `ProcessorMixin` when " - "implementing custom processors.") - return cast_output - - except Exception as exc: - msg = (f"Failed to apply {type(hf_processor).__name__} " - f"on data={data} with kwargs={allowed_kwargs}") - - raise ValueError(msg) from exc diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index b1ab84e08ba76..467cbaa8af48f 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -11,7 +11,6 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -283,7 +282,6 @@ try: op_func=_lora_expand, mutates_args=["output_tensor"], fake_impl=_lora_expand_fake, - dispatch_key=current_platform.dispatch_key, ) lora_expand = torch.ops.vllm.lora_expand diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 1e7075ab07151..57da93c226d25 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -11,7 +11,6 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -237,7 +236,6 @@ try: op_func=_lora_shrink, mutates_args=["output_tensor"], fake_impl=_lora_shrink_fake, - dispatch_key=current_platform.dispatch_key, ) lora_shrink = torch.ops.vllm.lora_shrink diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index b278e37415748..98437340fd242 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ 
b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -40,8 +40,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ssm_state_indices, num_accepted_tokens, scale, - N: tl.constexpr, # num of sequences - T: tl.constexpr, # num of tokens + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens B: tl.constexpr, H: tl.constexpr, HV: tl.constexpr, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..600bd4444535a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json @@ -0,0 +1,123 @@ +{ + "triton_version": "3.4.0", + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 
256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 0eec93601b3f2..114f349538fbe 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -98,13 +98,16 @@ def select_experts( e_score_correction_bias=e_score_correction_bias) elif custom_routing_function is None: assert scoring_func == "softmax" - topk_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + topk_logit_vals, topk_idx = torch.topk(router_logits, + k=top_k, + dim=-1, + sorted=False) if renormalize: - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids.to(torch.int32) + topk_vals = torch.softmax(topk_logit_vals, dim=-1) + else: + logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True) + topk_vals = (topk_logit_vals - logZ).exp() + return topk_vals.to(torch.float32), topk_idx.to(torch.int32) else: return 
custom_routing_function(hidden_states=hidden_states, gating_output=router_logits, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index e358143fac7c7..fe586a22e2506 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake( direct_register_custom_op( op_name="flashinfer_fused_moe_blockscale_fp8", op_func=flashinfer_fused_moe_blockscale_fp8, - mutates_args=[], fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, tags=(torch.Tag.needs_fixed_stride_order, ), ) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 1e3ac6cd79f68..eb12a9b0a233f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, direct_register_custom_op( op_name="fused_marlin_moe", op_func=fused_marlin_moe, - mutates_args=[], fake_impl=fused_marlin_moe_fake, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0e334fdf24045..611df357265bf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake( direct_register_custom_op( op_name="outplace_fused_experts", op_func=outplace_fused_experts, - mutates_args=[], fake_impl=outplace_fused_experts_fake, tags=(() if is_torch_equal_or_newer("2.7.0") else (torch.Tag.needs_fixed_stride_order, )), diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 0e84a9241e905..18de758519346 100644 --- 
a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -23,7 +23,7 @@ if has_triton_kernels(): from triton_kernels.routing import (RoutingData, routing, routing_from_bitmatrix) from triton_kernels.tensor import Bitmatrix - except (ModuleNotFoundError, AttributeError) as e: + except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " "version is compatible. Error: %s", e) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 71cc2bcf174dd..89e0cee08170b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -69,8 +69,6 @@ else: if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_grouped_topk as grouped_topk) -elif current_platform.is_cpu(): - pass else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): @@ -2040,7 +2038,6 @@ direct_register_custom_op( op_func=moe_forward, mutates_args=["hidden_states"], fake_impl=moe_forward_fake, - dispatch_key=current_platform.dispatch_key, tags=(torch.Tag.needs_fixed_stride_order, ), ) @@ -2071,7 +2068,6 @@ direct_register_custom_op( op_func=moe_forward_shared, mutates_args=["hidden_states"], fake_impl=moe_forward_shared_fake, - dispatch_key=current_platform.dispatch_key, tags=(torch.Tag.needs_fixed_stride_order, ), ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index f4972ff5f9cb0..2764af5fc5323 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -223,17 +223,13 @@ if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", 
op_func=rocm_aiter_asm_moe_tkw1_impl, - mutates_args=[], fake_impl=rocm_aiter_asm_moe_tkw1_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_fused_moe", op_func=rocm_aiter_fused_moe_impl, - mutates_args=[], fake_impl=rocm_aiter_fused_moe_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( @@ -241,7 +237,6 @@ if current_platform.is_rocm(): op_func=rocm_aiter_topk_softmax_impl, mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], fake_impl=rocm_aiter_topk_softmax_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( @@ -249,7 +244,6 @@ if current_platform.is_rocm(): op_func=rocm_aiter_biased_grouped_topk_impl, mutates_args=["topk_weights", "topk_ids"], fake_impl=rocm_aiter_biased_grouped_topk_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( @@ -257,7 +251,6 @@ if current_platform.is_rocm(): op_func=rocm_aiter_grouped_topk_impl, mutates_args=["topk_weights", "topk_ids"], fake_impl=rocm_aiter_grouped_topk_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index f875f712ba9c9..8123259d037ba 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -103,17 +103,13 @@ if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_rms_norm", op_func=rocm_aiter_rms_norm_impl, - mutates_args=[], fake_impl=rocm_aiter_rms_norm_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, - mutates_args=[], fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5bf96398bc710..df5bced6b2288 100644 --- 
a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm # yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, + ModelWeightParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, @@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ + "UnquantizedLinearMethod", "CompressedTensorsLinearMethod", "CompressedTensorsLinearTransformMethod", "BitBLASLinearMethod", @@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase): # The amount of memory allocated for the weights is # sum(output_partition_sizes) * input_size_per_partition. try: - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + weight_loader = extra_weight_attrs.pop("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) except torch.cuda.OutOfMemoryError as e: logger.error("Failed to create unquantized linear weights: %s", e) if torch.cuda.is_available(): @@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase): "Failed to create unquantized linear weights. 
" "This may be caused by insufficient memory to allocate " "the weight.") from e - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 6a901b47b8b63..410cbef4f6bc0 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata @@ -401,5 +400,4 @@ direct_register_custom_op( op_func=linear_attention, mutates_args=["output"], fake_impl=linear_attention_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a56ee13a63804..d64854cdb3818 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -464,5 +463,4 @@ direct_register_custom_op( op_func=mamba_mixer, mutates_args=["output"], fake_impl=mamba_mixer_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py 
b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 047ce4c4c43d0..908ea6e0025f1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( LoaderFunction, composed_weight_loader, sharded_weight_loader) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata @@ -765,5 +764,4 @@ direct_register_custom_op( op_func=mamba_mixer2, mutates_args=["output"], fake_impl=mamba_mixer2_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index a7b3c814859ce..2e657426143b1 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -17,7 +17,6 @@ from .mamba_ssm import softplus @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), triton.Config({'BLOCK_SIZE_H': 2}), triton.Config({'BLOCK_SIZE_H': 4}), triton.Config({'BLOCK_SIZE_H': 8}), diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index ffdcd702aab40..cc424760e229f 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) -from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.short_conv_attn import ( ShortConvAttentionMetadata) 
@@ -251,5 +250,4 @@ direct_register_custom_op( op_func=short_conv, mutates_args=["output"], fake_impl=short_conv_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 10f9085be4d12..a7d3e920414d8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config, int8_w8a16_moe_quant_config, nvfp4_moe_quant_config) +from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa @@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform +from vllm.platforms import CpuArchEnum, current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used @@ -63,7 +64,7 @@ __all__ = [ "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4MoeMethod" + "CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod" ] @@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): elif 
quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config) + elif quant_config._is_dynamic_token_w4a8_int(weight_quant, + input_quant): + return CompressedTensorsW4A8Int8MoEMethod(quant_config, + layer.moe_config) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") @@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_map=expert_map, quant_config=self.moe_quant_config, ) + + +class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): + """ + CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform + - Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles) + - Scales: Fp32 for Channelwise , bf16 for groupwise quantization + - Bias: Same data type as original weights + - Activations: FP32/Bf16 dynamic per-token (A8 Int), + quantized inside the kernel + """ + + def __init__( + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig): + super().__init__(moe) + self.has_bias = self.moe.has_bias + self.quant_config = quant_config + + # Validate scheme: weights=W4 (channel or group), + # activations=dynamic TOKEN (A8) + wq = self.quant_config.target_scheme_map["Linear"].get("weights") + aq = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + # Must be dynamic per-token activations + if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic: + raise ValueError( + "W4A8-int MoE needs dynamic per-token activation quantization." 
+ ) + + # Weight can be channel-wise (group_size=None) or group-wise + self.group_size = wq.group_size if (wq.group_size is not None) else -1 + if wq.num_bits != 4: + raise ValueError( + "This method only supports 4-bit weights (num_bits=4).") + + # CPU only + if not current_platform.is_cpu(): + raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.") + + # Arm: check _dyn ops availability + if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: + try: + _ = torch.ops.aten._dyn_quant_matmul_4bit + _ = torch.ops.aten._dyn_quant_pack_4bit_weight + except AttributeError as err: + raise RuntimeError( + f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; + install a newer build.""") from err + self.static_input_scales = False # always dynamic per token + + # ---- parameter creation ---- + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Shapes per local rank (TP/EP): + # w13: [E, 2*I_local, H] int8 (int4 values in [-8,7]) + # w2 : [E, H, I_local] int8 + # Scales: + # channel-wise: group_size=-1 -> per-output-row, single scale per row + # group-wise : group_size=g -> + # per-output-row, (in_features/g) scales + + E = num_experts + H = hidden_size + IN = intermediate_size_per_partition + g = self.group_size + + # Per-row scale columns + def _n_scale_cols(in_features: int) -> int: + return 1 if g == -1 else (in_features // g) + + # Register unpacked int4-as-int8 weights the loader will fill. 
+ w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8), + requires_grad=False) + set_weight_attrs(w13, extra_weight_attrs) + layer.register_parameter("w13_weight", w13) + + w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8), + requires_grad=False) + set_weight_attrs(w2, extra_weight_attrs) + layer.register_parameter("w2_weight", w2) + + # Register scales + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + + w13_s = torch.nn.Parameter(torch.ones(E, + 2 * IN, + _n_scale_cols(H), + dtype=scale_dtype), + requires_grad=False) + set_weight_attrs( + w13_s, { + "quant_method": "channel" if g == -1 else "group", + **extra_weight_attrs + }) + layer.register_parameter("w13_weight_scale", w13_s) + + w2_s = torch.nn.Parameter(torch.ones(E, + H, + _n_scale_cols(IN), + dtype=scale_dtype), + requires_grad=False) + set_weight_attrs( + w2_s, { + "quant_method": "channel" if g == -1 else "group", + **extra_weight_attrs + }) + layer.register_parameter("w2_weight_scale", w2_s) + + if self.has_bias: + w13_bias = torch.nn.Parameter(torch.zeros(E, + 2 * IN, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + # Placeholders for packed weights (will be replaced after packing) + layer.register_parameter( + "w13_weight_packed", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs) + + layer.register_parameter( + "w2_weight_packed", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs) + + # dims for 4 bit 
fused matmuls + layer.w13_in_features = H + layer.w13_out_features = 2 * IN + layer.w2_in_features = IN + layer.w2_out_features = H + layer.group_size = g + + # post-load packing to dyn-4bit KleidiAI kernel's format + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E = layer.w13_weight.shape[0] + H = layer.w13_in_features + I2 = layer.w13_out_features + IN = layer.w2_in_features + g = layer.group_size + + def _pack_matrix(int4_as_int8_2d: torch.Tensor, + scales_2d: torch.Tensor, + bias_1d: Optional[torch.Tensor], in_features: int, + out_features: int) -> torch.Tensor: + # int4 values are stored as int8 in [-8,7]. + # Shift to unsigned nibble and pack pairs along input-dim. + tmp = int4_as_int8_2d.add(8) # [out, in] + uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to( + torch.uint8) # [out, in//2] + + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + scales = scales_2d.to(scale_dtype) + bias = None if bias_1d is None else bias_1d.to(torch.float32) + return torch.ops.aten._dyn_quant_pack_4bit_weight( + uint8_nibbles, scales, bias, g if g != -1 else in_features, + in_features, out_features) + + # Pack per expert + w13_packed_list = [] + w2_packed_list = [] + + has_w13_bias = hasattr(layer, + "w13_bias") and layer.w13_bias is not None + has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None + + for e in range(E): + w13_packed_list.append( + _pack_matrix( + layer.w13_weight[e], # [2I, H] + layer.w13_weight_scale[e], # [2I, H/g or 1] + layer.w13_bias[e] if has_w13_bias else None, # [2I] + H, + I2)) + w2_packed_list.append( + _pack_matrix( + # w2 shape is [H, IN]; we need [out, in] == [H, IN]. 
+ layer.w2_weight[e], # [H, IN] + layer.w2_weight_scale[e], # [H, IN/g or 1] + layer.w2_bias[e] if has_w2_bias else None, # [H] + IN, + layer.w2_out_features # in_features=IN, out_features=H + )) + + # each packed tensor has identical shape per expert; stack on dim 0 + w13_packed = torch.stack(w13_packed_list, dim=0) + w2_packed = torch.stack(w2_packed_list, dim=0) + + replace_parameter(layer, "w13_weight_packed", + torch.nn.Parameter(w13_packed, requires_grad=False)) + replace_parameter(layer, "w2_weight_packed", + torch.nn.Parameter(w2_packed, requires_grad=False)) + + # free raw tensors/scales/bias now that they're packed into the payload. + replace_parameter( + layer, "w13_weight", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + replace_parameter( + layer, "w2_weight", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + replace_parameter( + layer, "w13_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + replace_parameter( + layer, "w2_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + if has_w13_bias: + replace_parameter( + layer, "w13_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + if has_w2_bias: + replace_parameter( + layer, "w2_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + # CPU dynamic 4-bit MoE path does not use modular kernels or + # fused_experts; quant config is not needed. 
+ return None + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." + assert activation in ( + "silu", "swigluoai", + "swiglu"), "Only SiLU/SwiGLUGU/SwiGLUUG are supported." + assert expert_map is None, """expert_map/EP not implemented + for CPU dyn-4bit MoE.""" + + def _act_kind(s: str) -> int: + # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU + if s == "swiglu": + return 0 + if s == "swigluoai": + return 1 + if s == "silu": + return 2 + raise ValueError(f"Unknown activation '{s}'") + + # Apply topk softmax on router output + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + return torch.ops._C.dynamic_4bit_int_moe( + x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed, + layer.w2_weight_packed, layer.w2_out_features, + layer.w2_in_features, layer.w13_out_features, 
layer.group_size, + apply_router_weight_on_input, int(_act_kind(activation))) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index c2b3ccf19fca8..8452f686b3acc 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -4,7 +4,6 @@ import logging import torch -from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils import direct_register_custom_op from vllm.utils.deep_gemm import fp8_gemm_nt @@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake( direct_register_custom_op( op_name="w8a8_deepgemm_block_scaled_mm", op_func=w8a8_deepgemm_block_scaled_mm, - mutates_args=[], fake_impl=w8a8_deepgemm_block_scaled_mm_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a631dfdab6544..de25ee84d081e 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -161,7 +161,6 @@ try: direct_register_custom_op( op_name="_fused_mul_mat_gguf", op_func=_fused_mul_mat_gguf, - mutates_args=[], fake_impl=_fused_mul_mat_gguf_fake, ) fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf @@ -273,7 +272,6 @@ try: direct_register_custom_op( op_name="_fused_moe_gguf", op_func=_fused_moe_gguf, - mutates_args=[], fake_impl=_fused_moe_gguf_fake, ) fused_moe_gguf = torch.ops.vllm._fused_moe_gguf @@ -319,7 +317,6 @@ try: direct_register_custom_op( op_name="_apply_gguf_embedding", op_func=_apply_gguf_embedding, - mutates_args=[], fake_impl=_apply_gguf_embedding_fake, ) apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 7f808fa92a9a8..e8e950a4bb7b6 100644 
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -51,9 +51,7 @@ if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_gemm_w8a8", op_func=rocm_aiter_gemm_w8a8_impl, - mutates_args=[], fake_impl=rocm_aiter_gemm_w8a8_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index a71c8d32a22c7..b710f6ee249b1 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) - elif current_platform.is_rocm() or ( - self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128) hidden_size = round_up(hidden_size, 128) + elif current_platform.is_rocm(): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 256) + hidden_size = round_up(hidden_size, 256) else: intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 64) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 2098086bf2401..a4cfc7d6c15c6 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -91,9 +91,7 @@ if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_gemm_w8a8_blockscale", 
op_func=rocm_aiter_gemm_w8a8_blockscale_impl, - mutates_args=[], fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, - dispatch_key=current_platform.dispatch_key, ) if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR and current_platform.is_fp8_fnuz()): @@ -132,13 +130,14 @@ def _w8a8_triton_block_scaled_mm_fake( device=qx.device) -direct_register_custom_op( - "w8a8_triton_block_scaled_mm_func", - _w8a8_triton_block_scaled_mm_func, - mutates_args=[], - fake_impl=_w8a8_triton_block_scaled_mm_fake, - dispatch_key="CUDA", -) +# Note: the check can be removed when CPU torch > 2.7 +if not current_platform.is_cpu(): + direct_register_custom_op( + "w8a8_triton_block_scaled_mm_func", + _w8a8_triton_block_scaled_mm_func, + fake_impl=_w8a8_triton_block_scaled_mm_fake, + dispatch_key="CUDA", + ) # TODO fix ROCm->Triton custom path: diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 3de928fea7202..fb1d041f34499 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from typing import Any, Callable, Optional import torch @@ -21,6 +21,10 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.tensor_details.layout import StridedLayout + + value_layout_opts: dict[str, Any] = {} + scale_layout_opts: dict[str, Any] = {} + if (current_platform.is_cuda() and current_platform.is_device_capability(90) and not is_torch_equal_or_newer("2.8.1")): @@ -28,8 +32,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): "Mxfp4 on hopper is running on torch < 2.8.1, " "this cause swizling to be disabled, which 
may " "cause performance degradation. Please upgrade to torch nightly") - value_layout, value_layout_opts = StridedLayout, dict() - scale_layout, scale_layout_opts = StridedLayout, dict() + value_layout = StridedLayout + scale_layout = StridedLayout + elif current_platform.is_rocm(): + from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout, + StridedLayout) + + from vllm.platforms.rocm import on_gfx950 + value_layout = StridedLayout + scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout else: value_layout, value_layout_opts = \ layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) @@ -113,7 +124,6 @@ try: direct_register_custom_op( op_name="dequant_mxfp4", op_func=_dequant_mxfp4, - mutates_args=[], fake_impl=_dequant_mxfp4_fake, ) dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4 @@ -124,7 +134,6 @@ try: direct_register_custom_op( op_name="quant_dequant_mxfp4", op_func=_quant_dequant_mxfp4, - mutates_args=[], fake_impl=_quant_dequant_mxfp4_fake, ) quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 6ed482db4700e..b434b7acfea83 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, direct_register_custom_op( op_name="rocm_per_tensor_w8a8_scaled_mm_impl", op_func=rocm_per_tensor_w8a8_scaled_mm_impl, - mutates_args=[], fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index e3cd0a8e788eb..8619651067746 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -147,5 +147,4 @@ direct_register_custom_op( 
op_func=_flashinfer_rotary_embedding, mutates_args=["query", "key"], # These tensors are modified in-place fake_impl=_flashinfer_rotary_embedding_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index d7a65d43c2107..96dd58c0e4d23 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module, direct_register_custom_op( op_name="rocm_unquantized_gemm_impl", op_func=rocm_unquantized_gemm_impl, - mutates_args=[], fake_impl=rocm_unquantized_gemm_impl_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 0f05f9b4efcd6..6fd8c2fb5c561 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union, cast +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -347,12 +347,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, - pixel_values: torch.Tensor, - **kwargs) -> torch.Tensor: - target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_features = vision_tower(pixel_values.to(dtype=target_dtype), - **kwargs) + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + **kwargs, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + target_dtype: torch.dtype = \ + 
vision_tower.get_input_embeddings().weight.dtype + image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \ + vision_tower(pixel_values.to(dtype=target_dtype), **kwargs) def select_features(leaf: torch.Tensor): return self._select_image_features( @@ -360,10 +364,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, strategy=self.config.vision_feature_select_strategy, ) - return cast( - Union[torch.Tensor, tuple[torch.Tensor, ...]], - json_map_leaves(select_features, image_features), - ) + return json_map_leaves(select_features, image_features) def _select_image_features(self, image_features: torch.Tensor, *, strategy: str) -> torch.Tensor: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index aa7bcf5b65ada..cab85ea347f4c 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -245,19 +245,6 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): } -class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig): - - @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config - config.max_seq_len_to_capture = config.max_model_len - logger.info( - "Setting max_seq_len_to_capture to %d " - "to ensure that CUDA graph capture " - "covers sequences of length up to max_model_len.", - config.max_model_len) - - class GptOssForCausalLMConfig(VerifyAndUpdateConfig): @staticmethod @@ -426,7 +413,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, "JambaForSequenceClassification": JambaForSequenceClassificationConfig, - "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, "GptOssForCausalLM": GptOssForCausalLMConfig, "MambaForCausalLM": MambaModelConfig, "Mamba2ForCausalLM": MambaModelConfig, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py 
index 415d36c681d8c..9895ebbcdefee 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op @@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor: direct_register_custom_op( op_name="sequence_parallel_chunk", op_func=sequence_parallel_chunk, - mutates_args=[], fake_impl=sequence_parallel_chunk_fake, - dispatch_key=current_platform.dispatch_key, tags=(torch.Tag.needs_fixed_stride_order, ), ) diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index f4d288fd887e9..0b6bccb334982 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -26,6 +26,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, GeluAndMul, @@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import SupportsQuant from .utils import (AutoWeightsLoader, extract_layer_index, @@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index, logger = 
init_logger(__name__) +EPS = torch.tensor(torch.finfo().min) + class Gemma3nAltUp(nn.Module): """Alternating updates (Altup) @@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module): return corrected_predictions -@support_torch_compile -class Gemma3nTextModel(nn.Module, SupportsQuant): +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. + kv_sharing_fast_prefill) +class Gemma3nSelfDecoder(nn.Module): + """ + Includes altup embedding and self decoder layers + """ - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + ): super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.config = config - self.quant_config = quant_config + quant_config = vllm_config.quant_config self.embed_tokens = VocabParallelEmbedding( config.vocab_size, @@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): prefix=f"{prefix}.altup_projections.{idx-1}", ) for idx in range(1, self.config.altup_num_inputs) ]) - self.altup_unembed_projections = nn.ModuleList([ - ColumnParallelLinear( - config.hidden_size, - config.hidden_size, - bias=False, - gather_output=True, - return_bias=False, - quant_config=quant_config, - prefix=f"{prefix}.altup_unembed_projections.{idx-1}", - ) for idx in range(1, self.config.altup_num_inputs) - ]) - - # Transformer blocks. 
- self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: Gemma3nDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - ) - self.eps = torch.tensor(torch.finfo().min) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) * self.embed_scale def get_per_layer_input_embeddings( self, input_ids: torch.Tensor) -> torch.Tensor: @@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): return self.embed_tokens_per_layer( per_layer_inputs_tokens) * self.embed_scale_per_layer - def forward( + def get_per_layer_inputs( self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - per_layer_inputs: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if inputs_embeds is not None: - hidden_states_0 = inputs_embeds - else: - hidden_states_0 = self.get_input_embeddings(input_ids) - + hidden_states_0: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor], + ) -> torch.Tensor: per_layer_projection = self.per_layer_model_projection(hidden_states_0) per_layer_projection = per_layer_projection.reshape( *hidden_states_0.shape[:-1], @@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ) per_layer_projection = self.per_layer_projection_norm( per_layer_projection) - if per_layer_inputs is not None: # Profiling run does not compute per_layer_inputs per_layer_inputs = per_layer_projection + per_layer_inputs per_layer_inputs *= self.per_layer_input_scale else: per_layer_inputs = per_layer_projection + return per_layer_inputs + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.embed_scale + + def altup_embed(self, 
hidden_states_0: torch.Tensor) -> torch.Tensor: # Altup embed. hidden_states = [hidden_states_0] * self.config.altup_num_inputs target_magnitude = torch.mean(hidden_states_0**2, dim=-1, @@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): dim=-1, keepdim=True)**0.5 hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - hidden_states = torch.stack(hidden_states, dim=0) + new_magnitude, EPS) + hidden_states = torch.stack(hidden_states, dim=-1) + return hidden_states - # Transformer blocks. - for layer_idx, layer in enumerate(self.layers): + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + hidden_states_0 = inputs_embeds + else: + hidden_states_0 = self.get_input_embeddings(input_ids) + + adjusted_per_layer_inputs = self.get_per_layer_inputs( + hidden_states_0, per_layer_inputs) + hidden_states = self.altup_embed(hidden_states_0) + + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + + return hidden_states, adjusted_per_layer_inputs + + +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. 
+ kv_sharing_fast_prefill) +class Gemma3nCrossDecoder(nn.Module): + """ + Cross-decoder layers + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start # [altup_num_inputs, num_tokens, hidden_size] hidden_states = layer( positions=positions, @@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): per_layer_input=per_layer_inputs[:, layer_idx, :], **kwargs, ) + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + return hidden_states + +# This disables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
+ cache_config.kv_sharing_fast_prefill) +class Gemma3nTextModel(nn.Module, SupportsQuant): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.altup_unembed_projections = nn.ModuleList([ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.altup_unembed_projections.{idx-1}", + ) for idx in range(1, self.config.altup_num_inputs) + ]) + + # Allocate config.num_kv_shared_layers layers for self-decoder + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3nDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + + first_kv_shared_layer_idx = (config.num_hidden_layers - + config.num_kv_shared_layers) + + # NOTE(sarckk): importing this top level seems to cause issues + # during running of tests. 
+ from vllm.compilation.backends import set_model_tag + + # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO) + with set_model_tag("self_decoder"): + self.self_decoder = Gemma3nSelfDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.self_decoder", + decoder_layers=self.layers[:first_kv_shared_layer_idx], + layer_idx_start=0, + ) + # Layer idx 20-30 are cross-decoder layers in YOCO + with set_model_tag("cross_decoder"): + self.cross_decoder = Gemma3nCrossDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.cross_decoder", + decoder_layers=self.layers[first_kv_shared_layer_idx:], + layer_idx_start=first_kv_shared_layer_idx, + ) + + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill + + if self.fast_prefill_enabled: + # Allocate static buffers for CUDAGraph + # TODO(sarckk): Extract this functionality to interface + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + device = next(self.parameters()).device + self.positions = torch.zeros(max_num_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (max_num_tokens, config.hidden_size, + self.config.altup_num_inputs), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + self.per_layer_inputs = torch.zeros( + (max_num_tokens, self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + + @property + def embed_tokens(self): + return self.self_decoder.embed_tokens + + def get_per_layer_input_embeddings( + self, input_ids: torch.Tensor) -> torch.Tensor: + return self.self_decoder.get_per_layer_input_embeddings(input_ids) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.self_decoder.get_input_embeddings(input_ids) + + def fast_prefill_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: 
Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + logits_indices_padded, num_logits_indices = None, None + attn_metadata = get_forward_context().attn_metadata + + # attn_metadata is None during dummy runs + if (self.fast_prefill_enabled and attn_metadata is not None): + assert isinstance(attn_metadata, dict) + # Last layer is a KV sharing layer + layer_attn_metadata = attn_metadata[ + self.layers[-1].self_attn.attn.layer_name] + if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)): + logits_indices_padded = ( + layer_attn_metadata.logits_indices_padded) + num_logits_indices = layer_attn_metadata.num_logits_indices + + # Copy inputs for cudagraph + batch_size = positions.size(0) + self.positions[:batch_size].copy_(positions) + self_decoder_hidden_states, per_layer_inputs_adjusted = \ + self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + + if logits_indices_padded is None: + logits_indices_padded = torch.arange( + positions.size(0), + dtype=positions.dtype, + device=positions.device, + ) + + # NOTE(sarckk): There is currently a bug caused by + # vLLM converting output of last piecewise CUDA graph + # to weakref, causing memory to be prematurely freed + # when there are multiple compilation units + # Keep .clone() until fix in + # https://github.com/vllm-project/vllm/pull/22282 + hidden_states = self_decoder_hidden_states.clone() + + # Copy inputs for cudagraph + num_padded_logits_indices = logits_indices_padded.size(0) + self.positions[:num_padded_logits_indices].copy_( + positions[logits_indices_padded]) + self.hidden_states[:num_padded_logits_indices].copy_( + self_decoder_hidden_states[logits_indices_padded]) + self.per_layer_inputs[:num_padded_logits_indices].copy_( + per_layer_inputs_adjusted[logits_indices_padded]) + cross_decoder_hidden_states = self.cross_decoder( 
+ positions=self.positions[:num_padded_logits_indices], + hidden_states=self.hidden_states[:num_padded_logits_indices], + per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices], + **kwargs, + ) + + if num_logits_indices is not None: + assert num_logits_indices > 0 + # Merge cross-decoder and self-decoder hidden states + hidden_states[logits_indices_padded[:num_logits_indices]] = ( + cross_decoder_hidden_states[:num_logits_indices]) + else: + hidden_states = cross_decoder_hidden_states + + return hidden_states + + def normal_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states, per_layer_inputs = self.self_decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + hidden_states = self.cross_decoder( + positions=positions, + hidden_states=hidden_states, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + return hidden_states + + def altup_unembed( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: # Altup unembed. 
- target_magnitude = torch.mean(hidden_states[0]**2, + target_magnitude = torch.mean(hidden_states[..., 0]**2, dim=-1, keepdim=True)**0.5 for i in range(1, self.config.altup_num_inputs): - hidden_states[i] = self.altup_unembed_projections[i - 1]( - hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, + hidden_states[..., i] = self.altup_unembed_projections[i - 1]( + hidden_states[..., i]) + new_magnitude = torch.mean(hidden_states[..., i]**2, dim=-1, keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size] - hidden_states = torch.mean(hidden_states, dim=0) + hidden_states[..., i] *= target_magnitude / torch.maximum( + new_magnitude, EPS) + # [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size] + hidden_states = torch.mean(hidden_states, dim=-1) + return hidden_states + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.fast_prefill_enabled: + hidden_states = self.fast_prefill_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + else: + hidden_states = self.normal_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + hidden_states = self.altup_unembed(hidden_states) return self.norm(hidden_states) def load_weights(self, weights: Iterable[tuple[str, @@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: + # decoder layer weights, altup_unembed_projections and rmsnorm + # are initialized in text model, others are in self decoder + if (not 
name.startswith('layers') + and not name.startswith('altup_unembed_projections') + and not name.startswith('norm')): + name = f"self_decoder.{name}" + if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): # Loading kv cache scales for compressed-tensors quantization diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 795b38e724eab..2c619396e6c0c 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -308,13 +308,11 @@ class GraniteModel(nn.Module): hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) - residual = None hidden_states *= self.config.embedding_multiplier else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) @@ -322,7 +320,6 @@ class GraniteModel(nn.Module): if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, - "residual": residual }) hidden_states = self.norm(hidden_states) @@ -475,10 +472,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): torch.zeros((batch_size, self.config.hidden_size), dtype=dtype, device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), }) def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 07200fef4799d..47ac22c4aeaa5 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -298,17 +298,14 @@ class GraniteMoeModel(nn.Module): else: hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.embedding_multiplier - residual = None else: assert intermediate_tensors is not None 
hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, - "residual": residual }) hidden_states = self.norm(hidden_states) return hidden_states @@ -523,10 +520,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): torch.zeros((batch_size, self.config.hidden_size), dtype=dtype, device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), }) def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index a5d118f084e6c..b434822bff0a9 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -195,17 +195,14 @@ class GraniteMoeSharedModel(nn.Module): else: hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.embedding_multiplier - residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, - "residual": residual }) hidden_states = self.norm(hidden_states) return hidden_states @@ -323,10 +320,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): torch.zeros((batch_size, self.config.hidden_size), dtype=dtype, device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), }) def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/hyperclovax_vision.py 
b/vllm/model_executor/models/hyperclovax_vision.py index 54167f9f10995..4d39ff9ae79ee 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -29,7 +29,6 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig from transformers.modeling_utils import no_init_weights from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache @@ -37,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e2d7b9f23b28a..4d8ed95b6cc8f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union, cast) + Union) import torch import torch.nn as nn @@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -28,8 +27,10 @@ from vllm.multimodal.inputs import 
(MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves @@ -622,7 +623,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) + image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \ + vision_tower(pixel_values) def select_features(leaf: torch.Tensor): return self._select_image_features( @@ -630,10 +632,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): strategy=self.config.vision_feature_select_strategy, ) - return cast( - Union[torch.Tensor, tuple[torch.Tensor, ...]], - json_map_leaves(select_features, image_features), - ) + return json_map_leaves(select_features, image_features) def _process_image_pixels( self, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index b2f020f3323e8..d81ac8c704e79 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Annotated, Literal, Optional, Union, cast +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -254,7 +254,8 @@ 
class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = tuple(vision_tower(p) for p in pixel_values) + image_features: tuple[torch.Tensor, ...] = \ + tuple(vision_tower(p) for p in pixel_values) def select_features(leaf: torch.Tensor): return self._select_image_features( @@ -262,10 +263,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, strategy=self.config.vision_feature_select_strategy, ) - return cast( - Union[torch.Tensor, tuple[torch.Tensor, ...]], - json_map_leaves(select_features, image_features), - ) + return json_map_leaves(select_features, image_features) # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 def pack_image_features(self, image_features: list[torch.Tensor], diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 94e3d7234b6f4..ba6da4403ae16 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -13,7 +13,6 @@ from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig, from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -27,8 +26,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + BaseProcessingInfo, + InputProcessingContext, + 
PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 50521b5937862..79e315f794893 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -32,7 +32,6 @@ from transformers.models.llama4.image_processing_llama4_fast import ( from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, @@ -47,8 +46,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 0292f3bf8317d..a7acf64f302bc 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -48,7 +48,6 @@ from vllm.model_executor.models.utils import ( is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import 
direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata @@ -490,7 +489,6 @@ direct_register_custom_op( op_func=plamo2_mamba_mixer, mutates_args=["output"], fake_impl=plamo2_mamba_mixer_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ab23b494e561e..356b5001a7dc8 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1225,7 +1225,6 @@ direct_register_custom_op( op_func=gdn_attention, mutates_args=["output"], fake_impl=gdn_attention_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6ab3fa902c387..ac0ec6ca146c9 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -483,23 +483,23 @@ class _LazyRegisteredModel(_BaseRegisteredModel): def inspect_model_cls(self) -> _ModelInfo: model_path = Path( __file__).parent / f"{self.module_name.split('.')[-1]}.py" + module_hash = None - assert model_path.exists(), \ - f"Model {self.module_name} expected to be on path {model_path}" - with open(model_path, "rb") as f: - module_hash = hashlib.md5(f.read()).hexdigest() + if model_path.exists(): + with open(model_path, "rb") as f: + module_hash = hashlib.md5(f.read()).hexdigest() - mi = self._load_modelinfo_from_cache(module_hash) - if mi is not None: - logger.debug(("Loaded model info " - "for class %s.%s from cache"), self.module_name, - self.class_name) - return mi - else: - logger.debug(("Cache model info " - "for class %s.%s miss. 
" - "Loading model instead."), self.module_name, - self.class_name) + mi = self._load_modelinfo_from_cache(module_hash) + if mi is not None: + logger.debug(("Loaded model info " + "for class %s.%s from cache"), self.module_name, + self.class_name) + return mi + else: + logger.debug(("Cache model info " + "for class %s.%s miss. " + "Loading model instead."), self.module_name, + self.class_name) # Performed in another process to avoid initializing CUDA mi = _run_in_subprocess( @@ -508,7 +508,8 @@ class _LazyRegisteredModel(_BaseRegisteredModel): self.class_name) # save cache file - self._save_modelinfo_to_cache(mi, module_hash) + if module_hash is not None: + self._save_modelinfo_to_cache(mi, module_hash) return mi diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 67cf3ccf315d1..3660efdc079aa 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -4,7 +4,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union, cast) + Union) import torch import torch.nn as nn @@ -17,7 +17,6 @@ from transformers.processing_utils import ProcessingKwargs, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -29,8 +28,9 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, PromptUpdate) from 
    @property
    def weight_loader(self) -> Callable:
        """Return the weight-loading callable for this parameter.

        Raises:
            AttributeError: If the loader was previously removed via
                ``del param.weight_loader``.
        """
        # NOTE(@ksayers) some models such as mamba_mixer2 override the
        # weight loader to support custom loading. In the future, model-specific
        # weight loading should be implemented via Model.load_weights. In the
        # meantime, support deleting and overriding the `weight_loader` attribute
        if self._weight_loader is None:
            # Raising AttributeError (not returning None) makes hasattr()
            # report the attribute as absent, mimicking a real deletion.
            raise AttributeError(f"{self.__class__.__name__} weight_loader "
                                 "attribute has been deleted")
        return self._weight_loader

    @weight_loader.setter
    def weight_loader(self, value: Callable):
        # Allow models with custom loading logic to replace the loader.
        self._weight_loader = value

    @weight_loader.deleter
    def weight_loader(self):
        # `del param.weight_loader` marks the loader as absent rather than
        # removing the slot; the getter then raises AttributeError.
        self._weight_loader = None  # type: ignore[assignment]
Copyright contributors to the vLLM project +import time from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, @@ -7,18 +8,20 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache -from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, - TypeVar, Union, cast) +from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, + Protocol, Union, cast, overload) import regex as re import torch -from typing_extensions import assert_never +from typing_extensions import TypeVar, assert_never -from vllm.inputs import InputProcessingContext from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) -from vllm.utils import flatten_2d_lists, full_groupby +from vllm.utils import (flatten_2d_lists, full_groupby, + get_allowed_kwarg_only_overrides) +from vllm.utils.jsontree import JSONTree, json_map_leaves from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, @@ -34,6 +37,8 @@ if TYPE_CHECKING: from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin + from vllm.config import ModelConfig + from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder @@ -875,6 +880,222 @@ def find_mm_placeholders( return dict(full_groupby_modality(it)) +_T = TypeVar("_T") +_C = TypeVar("_C", bound="PretrainedConfig", default="PretrainedConfig") +_P = TypeVar("_P", bound="ProcessorMixin", default="ProcessorMixin") + + +@dataclass(frozen=True) +class InputProcessingContext: + """ + Contains information about the model which may be used to + modify the inputs. 
+ """ + + model_config: "ModelConfig" + """The configuration of the model.""" + + tokenizer: AnyTokenizer + """The tokenizer used to tokenize the inputs.""" + + @overload + def get_hf_config(self, /) -> "PretrainedConfig": + ... + + @overload + def get_hf_config( + self, + typ: Union[type[_C], tuple[type[_C], ...]], + /, + ) -> _C: + ... + + def get_hf_config( + self, + typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None, + /, + ) -> Any: + """ + Get the HuggingFace configuration + (`transformers.PretrainedConfig`) of the model, + additionally checking its type. + + Raises: + TypeError: If the configuration is not of the specified type. + """ + if typ is None: + from transformers.configuration_utils import PretrainedConfig + + typ = PretrainedConfig + + hf_config = self.model_config.hf_config + if not isinstance(hf_config, typ): + raise TypeError("Invalid type of HuggingFace config. " + f"Expected type: {typ}, but " + f"found type: {type(hf_config)}") + + return hf_config + + def get_hf_image_processor_config(self) -> dict[str, Any]: + """ + Get the HuggingFace image processor configuration of the model. + """ + return self.model_config.hf_image_processor_config + + def get_mm_config(self): + """ + Get the multimodal config of the model. + + Raises: + RuntimeError: If the model is not a multimodal model. + """ + mm_config = self.model_config.multimodal_config + if mm_config is None: + raise RuntimeError("Not a multimodal model") + + return mm_config + + @overload + def get_hf_processor(self, /, **kwargs: object) -> "ProcessorMixin": + ... + + @overload + def get_hf_processor( + self, + typ: Union[type[_P], tuple[type[_P], ...]], + /, + **kwargs: object, + ) -> _P: + ... + + def get_hf_processor( + self, + typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None, + /, + **kwargs: object, + ) -> Any: + """ + Get the HuggingFace processor + (`transformers.ProcessorMixin`) of the model, + additionally checking its type. 
+ + Raises: + TypeError: If the processor is not of the specified type. + """ + if typ is None: + from transformers.processing_utils import ProcessorMixin + + typ = ProcessorMixin + + return cached_processor_from_config( + self.model_config, + processor_cls=typ, + tokenizer=self.tokenizer, + **kwargs, + ) + + def init_processor( + self, + typ: type[_T], + /, + **kwargs: object, + ) -> _T: + """ + Initialize a HuggingFace-like processor class, merging the + keyword arguments with those in the model's configuration. + """ + mm_config = self.model_config.get_multimodal_config() + base_kwargs = mm_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + + return typ(**merged_kwargs) + + def _postprocess_output( + self, + output: JSONTree, + ) -> JSONTree: + + def _postprocess_one(x: object): + if isinstance(x, torch.Tensor): # noqa: SIM102 + # This mimics the behavior of transformers.BatchFeature + if x.is_floating_point(): + x = x.to(dtype=self.model_config.dtype) + + return x + + return json_map_leaves(_postprocess_one, output) + + def call_hf_processor( + self, + hf_processor: "ProcessorMixin", + data: Mapping[str, object], + kwargs: Mapping[str, object] = {}, + *, + num_tries: int = 1, + max_tries: int = 5, + ) -> Union["BatchFeature", JSONTree]: + """ + Call `hf_processor` on the prompt `data` + (text, image, audio...) with configurable options `kwargs`. 
+ """ + assert callable(hf_processor) + + mm_config = self.model_config.get_multimodal_config() + merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) + + allowed_kwargs = get_allowed_kwarg_only_overrides( + hf_processor, + merged_kwargs, + requires_kw_only=False, + allow_var_kwargs=True, + ) + + try: + output = hf_processor(**data, + **allowed_kwargs, + return_tensors="pt") + except Exception as exc: + # See https://github.com/huggingface/tokenizers/issues/537 + if (isinstance(exc, RuntimeError) and exc + and exc.args[0] == "Already borrowed" + and num_tries < max_tries): + logger.warning( + "Failed to acquire tokenizer in current thread. " + "Retrying (%d/%d)...", num_tries, max_tries) + time.sleep(0.5) + return self.call_hf_processor( + hf_processor, + data, + kwargs, + num_tries=num_tries + 1, + max_tries=max_tries, + ) + + msg = (f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={allowed_kwargs}") + + raise ValueError(msg) from exc + + # this emulates output.to(dtype=self.model_config.dtype) + from transformers.feature_extraction_utils import BatchFeature + + if isinstance(output, BatchFeature): + output_ = self._postprocess_output(output.data) + return BatchFeature(output_) + + logger.warning_once( + "%s did not return `BatchFeature`. 
" + "Make sure to match the behaviour of `ProcessorMixin` when " + "implementing custom processors.", + type(hf_processor).__name__, + ) + + return self._postprocess_output(output) + + class BaseProcessingInfo: """Base class to provide the information necessary for data processing.""" diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 5d485bc361d11..2bbc0078ad13a 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -6,14 +6,14 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn -from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry from .cache import BaseMultiModalProcessorCache -from .processing import BaseMultiModalProcessor, BaseProcessingInfo +from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, + InputProcessingContext) from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) @@ -41,7 +41,7 @@ class ProcessingInfoFactory(Protocol[_I_co]): ... 
@cache
def on_gfx950() -> bool:
    """Return True if the current ROCm device is a gfx950 GPU.

    Cached for the process lifetime since the device architecture
    cannot change at runtime.
    """
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    # Substring match: gcnArchName may carry feature suffixes
    # (e.g. "gfx90a:sramecc+:xnack-"). A direct membership test replaces
    # the original any() over a single-element list.
    return "gfx950" in GPU_ARCH
a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index df9e84163f16c..39b08ec111073 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -7,7 +7,7 @@ import os from abc import abstractmethod from collections.abc import Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Union from vllm.logger import init_logger from vllm.utils import import_from_path, is_list_of @@ -77,7 +77,7 @@ class ReasoningParser: self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest], - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """ Extract reasoning content from a complete model-generated string. @@ -135,7 +135,7 @@ class ReasoningParserManager: def _register_module( cls, module: type, - module_name: Optional[Union[str, list[str]]] = None, + module_name: Union[str, list[str]] | None = None, force: bool = True, ) -> None: if not issubclass(module, ReasoningParser): @@ -155,7 +155,7 @@ class ReasoningParserManager: @classmethod def register_module( cls, - name: Optional[Union[str, list[str]]] = None, + name: Union[str, list[str]] | None = None, force: bool = True, module: Union[type, None] = None, ) -> Union[type, Callable]: diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py new file mode 100644 index 0000000000000..03cb882c26939 --- /dev/null +++ b/vllm/reasoning/basic_parsers.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from collections.abc import Sequence +from typing import Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, ResponsesRequest) +from vllm.reasoning.abs_reasoning_parsers import ReasoningParser +from 
class BaseThinkingReasoningParser(ReasoningParser):
    """
    Base class for reasoning parsers that use thinking tokens.

    Parsers of this kind delimit reasoning content with a start token and
    an end token; text between the two is reported as ``reasoning_content``
    and text after the end token as regular ``content``.

    Subclasses must implement the start and end tokens via the abstract
    ``start_token`` / ``end_token`` properties.
    """

    @property
    @abstractmethod
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        raise NotImplementedError

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        if not self.start_token or not self.end_token:
            raise ValueError(
                "start_token and end_token must be defined in subclasses")

        # Resolve the delimiter tokens to ids once up front; the streaming
        # path compares token ids instead of strings for speed.
        self.start_token_id = self.vocab.get(self.start_token)
        self.end_token_id = self.vocab.get(self.end_token)
        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                f"{self.__class__.__name__} reasoning parser could not locate "
                "think start/end tokens in the tokenizer!")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once the end-of-thinking token has been generated."""
        return self.end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens
        """
        # An end token in the final position is ignored: content only
        # exists once at least one token follows the end token.
        if self.end_token_id not in input_ids[:-1]:
            return []
        else:
            return input_ids[input_ids.index(self.end_token_id) + 1:]

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        """
        # A delta that is exactly one delimiter token carries no text.
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.start_token_id, self.end_token_id
        ]):
            return None

        # Check if start token is present in previous or delta.
        # Keep compatibility with models that don't generate start tokens.
        if self.start_token_id in previous_token_ids:
            if self.end_token_id in delta_token_ids:
                # start token in previous, end token in delta,
                # extract reasoning content
                end_index = delta_text.find(self.end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token):]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            elif self.end_token_id in previous_token_ids:
                # start token in previous, end token in previous:
                # reasoning already finished, everything now is content
                return DeltaMessage(content=delta_text)
            else:
                # start token in previous, no end token in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.start_token_id in delta_token_ids:
            if self.end_token_id in delta_token_ids:
                # start token in delta, end token in delta,
                # extract reasoning content between them
                start_index = delta_text.find(self.start_token)
                end_index = delta_text.find(self.end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.start_token):end_index]
                content = delta_text[end_index + len(self.end_token):]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            else:
                # start token in delta, no end token in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # no thinking start token seen at all:
            # treat the whole delta as regular content
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: Union[ChatCompletionRequest,
                                                ResponsesRequest]
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from the model output.

        This is the base implementation that works for most models.
        Subclasses can override this method for specific behavior.
        """
        # Check if the start token is present in the model output, remove it
        # if it is present.
        model_output_parts = model_output.partition(self.start_token)
        model_output = model_output_parts[2] if model_output_parts[
            1] else model_output_parts[0]

        # For models that may not generate start token,
        # assume the reasoning content is always at the start.
        if self.end_token not in model_output:
            return model_output, None
        else:
            reasoning_content, _, content = model_output.partition(
                self.end_token)
            # If generation stops right after end-of-think, return null content
            final_content = content or None
            return reasoning_content, final_content
@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.
    """

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        # NOTE: must be a non-empty literal; the base class rejects empty
        # delimiter tokens at construction time.
        return "<think>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "</think>"

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Streaming variant with DeepSeek-specific handling of outputs that
        omit the <think> start token.

        DeepSeek R1 may produce reasoning without ever emitting <think>
        (see the DeepSeek-R1 chat-template change, commit 8a58a13 on the
        HuggingFace hub), so when no start token has been seen, the text
        before </think> is still treated as reasoning content rather than
        as regular content.
        """
        ret = super().extract_reasoning_content_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        # Override only the "no start token anywhere" case of the base class.
        if (ret is not None and self.start_token_id not in previous_token_ids
                and self.start_token_id not in delta_token_ids):
            if self.end_token_id in delta_token_ids:
                # end token in delta with more tokens,
                # extract reasoning content and content
                end_index = delta_text.find(self.end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token):]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            elif self.end_token_id in previous_token_ids:
                # end token in previous, thinking content ends
                return DeltaMessage(content=delta_text)
            else:
                # no end token in previous or delta, reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)

        return ret
b/vllm/reasoning/qwen3_reasoning_parser.py @@ -1,21 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence from typing import Optional, Union -from transformers import PreTrainedTokenizerBase - from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) -from vllm.logger import init_logger -from vllm.reasoning import ReasoningParser, ReasoningParserManager - -logger = init_logger(__name__) + ResponsesRequest) +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser @ReasoningParserManager.register_module("qwen3") -class Qwen3ReasoningParser(ReasoningParser): +class Qwen3ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for the Qwen3 model. @@ -26,100 +21,25 @@ class Qwen3ReasoningParser(ReasoningParser): output. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) - self.think_start_token = "" - self.think_end_token = "" + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "" - if not self.model_tokenizer: - raise ValueError( - "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") - - self.think_start_token_id = self.vocab.get(self.think_start_token) - self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.think_start_token_id is None - or self.think_end_token_id is None): - raise RuntimeError( - "Qwen3 reasoning parser could not locate think start/end " - "tokens in the tokenizer!") - - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.think_end_token_id in input_ids - - def extract_content_ids(self, input_ids: list[int]) -> list[int]: - """ - Extract the content after the end tokens - """ - if self.think_end_token_id not in input_ids[:-1]: - return [] - else: - return 
input_ids[input_ids.index(self.think_end_token_id) + 1:] - - def extract_reasoning_content_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: - """ - Extract reasoning content from a delta message. - Handles streaming output where previous + delta = current. - Uses token IDs for faster processing. - For text abcxyz: - - 'abc' goes to reasoning_content - - 'xyz' goes to content - """ - # Skip single special tokens - if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id - ]): - return None - - if self.think_start_token_id in previous_token_ids: - if self.think_end_token_id in delta_token_ids: - # in previous, in delta, - # extract reasoning content - end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - elif self.think_end_token_id in previous_token_ids: - # in previous, in previous, - # reasoning content continues - return DeltaMessage(content=delta_text) - else: - # in previous, no in previous or delta, - # reasoning content continues - return DeltaMessage(reasoning_content=delta_text) - elif self.think_start_token_id in delta_token_ids: - if self.think_end_token_id in delta_token_ids: - # in delta, in delta, extract reasoning content - start_index = delta_text.find(self.think_start_token) - end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[start_index + - len(self.think_start_token - ):end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - else: - # in delta, no in delta, - # reasoning 
content continues - return DeltaMessage(reasoning_content=delta_text) - else: - # thinking is disabled, just content - return DeltaMessage(content=delta_text) + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "" def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: Union[ChatCompletionRequest, + ResponsesRequest] ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from the model output. + + Qwen3 has stricter requirements - it needs both start and end tokens + to be present, unlike other models that work with just the end token. For text abcxyz: - 'abc' goes to reasoning_content @@ -129,23 +49,24 @@ class Qwen3ReasoningParser(ReasoningParser): tuple[Optional[str], Optional[str]]: reasoning content and content """ - # Check if the model output contains the and tokens. - if (self.think_start_token not in model_output - or self.think_end_token not in model_output): + # Check if the model output contains both and tokens. + if (self.start_token not in model_output + or self.end_token not in model_output): return None, model_output + # Check if the is present in the model output, remove it # if it is present. - model_output_parts = model_output.partition(self.think_start_token) + model_output_parts = model_output.partition(self.start_token) model_output = model_output_parts[2] if model_output_parts[ 1] else model_output_parts[0] + # Check if the model output contains the tokens. # If the end token is not found, return the model output as is. - if self.think_end_token not in model_output: + if self.end_token not in model_output: return None, model_output # Extract reasoning content from the model output. 
- reasoning_content, _, content = model_output.partition( - self.think_end_token) + reasoning_content, _, content = model_output.partition(self.end_token) final_content = content or None return reasoning_content, final_content diff --git a/vllm/reasoning/seedoss_reasoning_parser.py b/vllm/reasoning/seedoss_reasoning_parser.py new file mode 100644 index 0000000000000..5f4bbbf1557eb --- /dev/null +++ b/vllm/reasoning/seedoss_reasoning_parser.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser + + +@ReasoningParserManager.register_module("seed_oss") +class SeedOSSReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for SeedOSS model. + + The SeedOSS model uses ... tokens to + denote reasoning content text. This parser extracts + the reasoning content from the model output. + Similar to DeepSeek R1, it supports cases + where the model doesn't generate the start token. + """ + + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "" + + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "" diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py index b7bee1974de5b..08466ca19b8a4 100644 --- a/vllm/transformers_utils/runai_utils.py +++ b/vllm/transformers_utils/runai_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import hashlib import os import shutil import signal @@ -56,12 +57,18 @@ class ObjectStorageModel: pull_files(): Pull model from object storage to the temporary directory. 
""" - def __init__(self) -> None: + def __init__(self, url: str) -> None: for sig in (signal.SIGINT, signal.SIGTERM): existing_handler = signal.getsignal(sig) signal.signal(sig, self._close_by_signal(existing_handler)) - self.dir = tempfile.mkdtemp() + dir_name = os.path.join( + tempfile.gettempdir(), + hashlib.sha256(str(url).encode()).hexdigest()[:8]) + if os.path.exists(dir_name): + shutil.rmtree(dir_name) + os.makedirs(dir_name) + self.dir = dir_name def __del__(self): self._close() diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 5d165f1662383..0a7af79f7a177 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2546,10 +2546,10 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa def direct_register_custom_op( op_name: str, op_func: Callable, - mutates_args: list[str], + mutates_args: Optional[list[str]] = None, fake_impl: Optional[Callable] = None, target_lib: Optional[Library] = None, - dispatch_key: str = "CUDA", + dispatch_key: Optional[str] = None, tags: tuple[torch.Tag, ...] 
= (), ): """ @@ -2577,6 +2577,13 @@ def direct_register_custom_op( "the required dependencies.") return + if mutates_args is None: + mutates_args = [] + + if dispatch_key is None: + from vllm.platforms import current_platform + dispatch_key = current_platform.dispatch_key + import torch.library if hasattr(torch.library, "infer_schema"): schema_str = torch.library.infer_schema(op_func, diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 2179bddae2435..ebc7a56ff906a 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -181,6 +181,12 @@ def force_use_trtllm_attention() -> Optional[bool]: return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) +def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: + """Check if the current configuration supports TRTLLM attention.""" + has_trtllm = supports_trtllm_attention() + return has_trtllm and (num_qo_heads % num_kv_heads == 0) + + def use_trtllm_attention( num_qo_heads: int, num_kv_heads: int, @@ -188,7 +194,9 @@ def use_trtllm_attention( max_seq_len: int, kv_cache_dtype: str, q_dtype: torch.dtype, + is_prefill: bool, has_sinks: bool = False, + has_spec: bool = False, ) -> bool: """Return ``True`` if TRTLLM attention is used.""" force_use_trtllm = force_use_trtllm_attention() @@ -214,6 +222,12 @@ def use_trtllm_attention( ) return False + if has_spec and not is_prefill: + # Speculative decoding requires TRTLLM attention for decodes + logger.info_once( + "Using TRTLLM attention (enabled for speculative decoding).") + return True + # Must use TRTLLM attention if query is FP8 quantized if q_dtype == current_platform.fp8_dtype(): if has_sinks: @@ -391,6 +405,7 @@ __all__ = [ "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", "supports_trtllm_attention", + "can_use_trtllm_attention", "use_trtllm_attention", "flashinfer_disable_q_quantization", "flashinfer_scaled_fp4_mm", diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py index 
457afb7e2c6ff..804c443eb1841 100644 --- a/vllm/utils/jsontree.py +++ b/vllm/utils/jsontree.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from functools import reduce -from typing import Callable, TypeVar, Union, overload +from typing import Callable, TypeVar, Union, cast, overload _T = TypeVar("_T") _U = TypeVar("_U") @@ -30,10 +30,42 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: yield value +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: Union[_T, dict[str, _T]], +) -> Union[_U, dict[str, _U]]: + ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: Union[_T, list[_T]], +) -> Union[_U, list[_U]]: + ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: Union[_T, tuple[_T, ...]], +) -> Union[_U, tuple[_U, ...]]: + ... + + +@overload def json_map_leaves( func: Callable[[_T], _U], value: JSONTree[_T], ) -> JSONTree[_U]: + ... + + +def json_map_leaves( + func: Callable[[_T], _U], + value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]], +) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]: """Apply a function to each leaf in a nested JSON structure.""" if isinstance(value, dict): return {k: json_map_leaves(func, v) for k, v in value.items()} @@ -45,6 +77,33 @@ def json_map_leaves( return func(value) +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: Union[_T, dict[str, _T]], + /, +) -> _T: + ... + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: Union[_T, list[_T]], + /, +) -> _T: + ... + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: Union[_T, tuple[_T, ...]], + /, +) -> _T: + ... 
+ + @overload def json_reduce_leaves( func: Callable[[_T, _T], _T], @@ -65,10 +124,10 @@ def json_reduce_leaves( def json_reduce_leaves( - func: Callable[..., Union[_T, _U]], - value: JSONTree[_T], - initial: _U = ..., # type: ignore[assignment] - /, + func: Callable[..., Union[_T, _U]], + value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]], + initial: _U = cast(_U, ...), # noqa: B008 + /, ) -> Union[_T, _U]: """ Apply a function of two arguments cumulatively to each leaf in a diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d564cf9988ea6..a2e18f970bec0 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -194,10 +194,9 @@ class FlashAttentionMetadataBuilder( self.use_full_cuda_graph = \ self.compilation_config.cudagraph_mode.has_full_cudagraphs() + self.max_cudagraph_size = self.compilation_config.max_capture_size if self.use_full_cuda_graph and self.aot_schedule: - self.max_cudagraph_size = self.compilation_config.max_capture_size - if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. @@ -259,6 +258,15 @@ class FlashAttentionMetadataBuilder( self.aot_schedule = False aot_schedule = False + max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible + if self.use_full_cuda_graph and \ + num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits + def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): cache_dtype = self.cache_config.cache_dtype @@ -281,7 +289,7 @@ class FlashAttentionMetadataBuilder( page_size=self.block_size, causal=causal, window_size=self.aot_sliding_window, - num_splits=self.max_num_splits, + num_splits=max_num_splits, ) return None @@ -322,7 +330,6 @@ class FlashAttentionMetadataBuilder( max_seq_len=max_seq_len, causal=causal) # For FA3 + full cudagraph - max_num_splits = 0 if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] self.scheduler_metadata[:n] = scheduler_metadata @@ -333,13 +340,6 @@ class FlashAttentionMetadataBuilder( self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - if num_actual_tokens <= self.max_cudagraph_size: - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. 
- max_num_splits = self.max_num_splits - attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cb092aa74e7f1..891108f961b5d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import (flashinfer_disable_q_quantization, +from vllm.utils.flashinfer import (can_use_trtllm_attention, + flashinfer_disable_q_quantization, supports_trtllm_attention, use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention @@ -48,6 +49,16 @@ FP4_DTYPE = torch.uint8 logger = init_logger(__name__) +trtllm_gen_workspace_buffer = None + + +def _get_trtllm_gen_workspace_buffer(): + global trtllm_gen_workspace_buffer + if trtllm_gen_workspace_buffer is None: + trtllm_gen_workspace_buffer = torch.zeros( + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda') + return trtllm_gen_workspace_buffer + @triton.jit def _trtllm_prefill_attn_kvfp8_dequant( @@ -213,6 +224,7 @@ class FlashInferMetadata: # For flashinfer trtllm batch decode max_q_len: int + max_q_len_prefill: int max_seq_len: int seq_lens: torch.Tensor block_table_tensor: torch.Tensor @@ -240,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -292,6 +304,10 @@ class 
FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): else: self.q_data_type = self.model_config.dtype + supports_spec_as_decode = \ + can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + self._init_reorder_batch_threshold(1, supports_spec_as_decode) + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers @@ -406,7 +422,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + decode_threshold=self.reorder_batch_threshold, + require_uniform=True) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len @@ -481,20 +498,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_last_page_len_np, ) + uses_spec_reorder = self.reorder_batch_threshold > 1 prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, self.num_kv_heads, num_prefill_tokens, max_seq_len, self.cache_dtype, self.q_data_type, - has_sinks=self.has_sinks) + is_prefill=True, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder) decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, self.num_kv_heads, num_decode_tokens, max_seq_len, self.cache_dtype, self.q_data_type, - has_sinks=self.has_sinks) + is_prefill=False, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder) if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): raise NotImplementedError( "FlashInfer backend currently does not support attention " @@ -511,6 +533,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): q_data_type=self.q_data_type, slot_mapping=common_attn_metadata.slot_mapping, max_q_len=max_q_len, + max_q_len_prefill=max_q_len, max_seq_len=max_seq_len, seq_lens=seq_lens, 
block_table_tensor=block_table_tensor, @@ -567,6 +590,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ prefill_start] paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] + + # Recompute max_q_len for the slice of requests we are using + # for prefills. This can be different from max_q_len when + # we have a non-uniform batch with some short decodes offloaded + # to the prefill pathway + query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] + attn_metadata.max_q_len_prefill = \ + int(query_lens_prefill.max().item()) + if not attn_metadata.prefill_use_trtllm: attn_metadata.prefill_wrapper.plan( qo_indptr_cpu, @@ -597,7 +629,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_decodes <= self._decode_cudagraph_max_bs) if use_cudagraph: num_input_tokens = ( - self.vllm_config.pad_for_cudagraph(num_decodes)) + self.vllm_config.pad_for_cudagraph(num_decode_tokens)) # Carefully fulfill the padding region with reasonable value # on cpu. # Make sure paged_kv_indptr_cpu is not decreasing @@ -611,7 +643,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_decodes:num_input_tokens].fill_(1) else: - num_input_tokens = num_decodes + num_input_tokens = num_decode_tokens attn_metadata.decode_wrapper = self._get_decode_wrapper( num_input_tokens, use_cudagraph) @@ -832,6 +864,9 @@ class FlashInferImpl(AttentionImpl): output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output + # When using spec decoding, num_decodes can be < num_decode_tokens + # because some decode requests may have more than one query token. 
+ num_decodes = attn_metadata.num_decodes num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -862,10 +897,10 @@ class FlashInferImpl(AttentionImpl): else: # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() - workspace_buffer = prefill_wrapper._float_workspace_buffer + workspace_buffer = _get_trtllm_gen_workspace_buffer() block_tables_prefill = attn_metadata.block_table_tensor[ - num_decode_tokens:] - seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:] + num_decodes:] + seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" @@ -909,7 +944,7 @@ class FlashInferImpl(AttentionImpl): workspace_buffer=workspace_buffer, block_tables=mock_block_table, seq_lens=seq_lens_prefill, - max_q_len=attn_metadata.max_q_len, + max_q_len=attn_metadata.max_q_len_prefill, max_kv_len=attn_metadata.max_seq_len, bmm1_scale=self.bmm1_scale, bmm2_scale=self.bmm2_scale, @@ -943,7 +978,7 @@ class FlashInferImpl(AttentionImpl): else: # decode_query may be non-contiguous decode_query = decode_query.contiguous() - workspace_buffer = decode_wrapper._float_workspace_buffer + workspace_buffer = _get_trtllm_gen_workspace_buffer() block_tables_decode = attn_metadata.\ block_table_tensor[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] @@ -966,6 +1001,14 @@ class FlashInferImpl(AttentionImpl): assert self.o_sf_scale is None out = output[:num_decode_tokens] + if num_decode_tokens % attn_metadata.num_decodes != 0: + # This gets triggered when the dummy_run forces + # attention to be initialized with q_len = 0 + q_len_per_req = 1 + else: + q_len_per_req = \ + num_decode_tokens // attn_metadata.num_decodes + trtllm_batch_decode_with_kv_cache( query=decode_query, kv_cache=kv_cache_permute, @@ -979,7 +1022,7 @@ class FlashInferImpl(AttentionImpl): sinks=self.sinks, 
o_sf_scale=self.o_sf_scale, out=out, - ) + q_len_per_req=q_len_per_req) return output_padded diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 06a87a4a3c8b2..843958bc79de3 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Backend for GatedDeltaNet attention.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch @@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder( cudagraph_support = AttentionCGSupport.UNIFORM_BATCH - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -76,7 +76,7 @@ class GDNAttentionMetadataBuilder( else: self.num_spec = 0 self.use_spec_decode = self.num_spec > 0 - self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc] + self._init_reorder_batch_threshold(1, self.use_spec_decode) self.use_full_cuda_graph = \ self.compilation_config.cudagraph_mode.has_full_cudagraphs() diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index 3ff201d83a79b..0dc62d6680208 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar import torch @@ -35,7 +34,7 @@ class LinearAttentionMetadata: class LinearAttentionMetadataBuilder( AttentionMetadataBuilder[LinearAttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): diff --git 
a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 9970331a6042c..ef342ce421ae1 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -16,7 +16,7 @@ M = TypeVar("M") class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index a177117a50bd1..3e8dba14ee2e9 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -190,7 +190,7 @@ return curr_o @ W_O import functools from abc import abstractmethod from dataclasses import dataclass, field -from typing import ClassVar, Generic, Optional, TypeVar, Union +from typing import Generic, Optional, TypeVar, Union import torch from tqdm import tqdm @@ -204,7 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -434,7 +434,35 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 + + @staticmethod + def determine_chunked_prefill_workspace_size( + vllm_config: VllmConfig) -> int: + scheduler_config = vllm_config.scheduler_config + cache_config = 
vllm_config.cache_config + model_config = vllm_config.model_config + + chunked_prefill_workspace_size = min( + # Try for 8 full length request or at least 4 pages per-request + max(8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 64 * 1024) + + # Enforce that we enough for at least 1 page per request + chunked_prefill_workspace_size = max( + chunked_prefill_workspace_size, + scheduler_config.max_num_seqs * cache_config.block_size) + + return chunked_prefill_workspace_size def __init__(self, kv_cache_spec: AttentionSpec, @@ -448,7 +476,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config - cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config self.device = device @@ -468,22 +495,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max(8 * self.model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # 
(assuming 192 QK head dim, 128 heads, and fp16) - 64 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size + self.chunked_prefill_workspace_size = \ + self.determine_chunked_prefill_workspace_size(vllm_config) + if self.dcp_world_size > 1: # Note(hc): The local kvcache is incomplete when DCP is triggered, # an additional kvcache allgather across the DCP group is therefore @@ -999,6 +1013,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.dcp_world_size: Optional[int] = None + self.chunked_prefill_workspace_size = \ + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config()) + def _flash_attn_varlen_diff_headdims(self, q, k, @@ -1513,6 +1531,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): " for MLACommonImpl") if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + (self.chunked_prefill_workspace_size, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. 
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 4ad9a13b61d8e..652b1cdb6b767 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -64,7 +64,7 @@ class FlashAttnMLAMetadataBuilder( cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_BATCH - reorder_batch_threshold: ClassVar[int] = 512 + reorder_batch_threshold: int = 512 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -99,7 +99,7 @@ class FlashAttnMLAMetadataBuilder( # TODO(lucas): Until we add support for the DCP custom masking we need # to restrict decodes to q_len == 1 when DCP is enabled. - self.__class__.reorder_batch_threshold = 1 \ + self.reorder_batch_threshold = 1 \ if get_dcp_group().world_size > 1 else self.reorder_batch_threshold def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index 428e409659798..df7f0d2310ab4 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch @@ -41,7 +41,7 @@ class ShortConvAttentionMetadata: class ShortConvAttentionMetadataBuilder( AttentionMetadataBuilder[ShortConvAttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f837439f953e8..0c6e0dfefd8a2 100644 --- 
a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -236,7 +236,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. - reorder_batch_threshold: ClassVar[Optional[int]] = None + reorder_batch_threshold: Optional[int] = None @abstractmethod def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], @@ -246,6 +246,22 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): self.vllm_config = vllm_config self.device = device + def _init_reorder_batch_threshold( + self, + reorder_batch_threshold: int = 1, + supports_spec_as_decode: bool = False) -> None: + self.reorder_batch_threshold = reorder_batch_threshold + if self.reorder_batch_threshold is not None \ + and supports_spec_as_decode: + # If the backend supports spec-as-decode kernels, then we can set + # the reorder_batch_threshold based on the number of speculative + # tokens from the config. + speculative_config = self.vllm_config.speculative_config + if (speculative_config is not None + and speculative_config.num_speculative_tokens is not None): + self.reorder_batch_threshold = \ + 1 + speculative_config.num_speculative_tokens + @abstractmethod def build(self, common_prefix_len: int, @@ -703,9 +719,9 @@ def subclass_attention_backend( def split_decodes_and_prefills( - common_attn_metadata: CommonAttentionMetadata, - decode_threshold: int = 1, -) -> tuple[int, int, int, int]: + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, + require_uniform: bool = False) -> tuple[int, int, int, int]: """ Assuming a reordered batch, finds the boundary between prefill and decode requests. @@ -714,6 +730,9 @@ def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata object containing the batch metadata. decode_threshold: The maximum query length to be considered a decode. 
+ require_uniform: If True, requires that all decode requests have the + same query length. When set, some queries may be considered prefills + even if they are <= decode_threshold, in order to ensure uniformity. Returns: num_decodes: The number of decode requests. @@ -726,11 +745,20 @@ def split_decodes_and_prefills( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - if max_query_len <= decode_threshold: + if max_query_len <= decode_threshold and \ + (not require_uniform or decode_threshold <= 1): return num_reqs, 0, num_tokens, 0 query_lens = query_start_loc[1:] - query_start_loc[:-1] - is_prefill = query_lens > decode_threshold + if query_lens[0].item() > decode_threshold: + # first request is not decode, so no decode requests + return 0, num_reqs, 0, num_tokens + + if require_uniform: + is_prefill = query_lens != query_lens[0] + else: + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 @@ -806,6 +834,38 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch +def reshape_query_for_spec_decode(query: torch.Tensor, + batch_size: int) -> torch.Tensor: + """ + Reshapes the query tensor for the specified batch size, so that + it has shape (batch_size, seq_len, num_heads, head_dim). + """ + assert query.dim() == 3, f"query must be 3D, got {query.dim()}D" + total_tokens = query.shape[0] + num_heads = query.shape[1] + head_dim = query.shape[2] + assert total_tokens % batch_size == 0, ( + f"{total_tokens=} is not divisible by {batch_size=}") + seq_len = total_tokens // batch_size + return query.view(batch_size, seq_len, num_heads, head_dim) + + +def reshape_attn_output_for_spec_decode( + attn_output: torch.Tensor) -> torch.Tensor: + """ + Reshapes the attention output tensor, so that + the batch_size and seq_len dimensions are combined. 
+ """ + if attn_output.dim() == 3: + # Already in the correct shape + return attn_output + assert attn_output.dim() == 4, \ + f"attn_output must be 4D, got {attn_output.dim()}D" + total_tokens = attn_output.shape[0] * attn_output.shape[1] + return attn_output.view(total_tokens, attn_output.shape[2], + attn_output.shape[3]) + + KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ ('logits_indices_padded', Optional[torch.Tensor], None), ('num_logits_indices', int, 0), diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index a6ca334912353..d5a6c4c1db52d 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, Optional import torch @@ -197,7 +197,7 @@ class XFormersAttentionMetadata: class XFormersAttentionMetadataBuilder( AttentionMetadataBuilder[XFormersAttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 def __init__( self, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index d1e1c1c8d0382..3cc738304821b 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from collections.abc import Iterable -from typing import Optional +from typing import Any, Optional, Union from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, BlockRemoved, BlockStored, @@ -19,6 +18,103 @@ from vllm.v1.request import Request logger = init_logger(__name__) +class BlockHashToBlockMap: + """ + Cache of blocks that are used for prefix caching. It caches blocks + from hash directly to a block or multiple blocks + (i.e. 
{block_hash: KVCacheBlocks}) + - Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks + would simply be a KVCacheBlock. + - Otherwise, KVCacheBlocks is a dict from {block_id: KVCacheBlock} + + A cached block is a full block with a block hash that can be used + for prefix caching. + The cached block may be used by running requests or in the + free_block_queue that could potentially be evicted. + + NOTE #1: We currently don't de-duplicate the blocks in the cache, + meaning that if a block becomes full and is cached, we don't check + if there is already an identical block in the cache. This is because + we want to make sure the allocated block IDs won't change so that + block tables are append-only. + NOTE #2: The union type is introduced in order to reduce GC costs + from the inner dict. + """ + + def __init__(self): + self._cache: dict[BlockHashWithGroupId, + Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {} + + def get_one_block(self, + key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: + """ + Gets any block with the given block hash key. 
+ """ + blocks = self._cache.get(key) + if blocks is not None: + if isinstance(blocks, KVCacheBlock): + return blocks + if isinstance(blocks, dict): + return next(iter(blocks.values())) + self._unexpected_blocks_type(blocks) + return None + + def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: + """ + Inserts the KVCacheBlock to the cache + """ + blocks = self._cache.get(key) + if blocks is None: + # When key is not found, attach a single block to the key + self._cache[key] = block + elif isinstance(blocks, KVCacheBlock): + # If there's a block with the same key, merge the original block + # and the new block into a dict + self._cache[key] = {blocks.block_id: blocks, block.block_id: block} + elif isinstance(blocks, dict): + # If it's already a dict, simply insert the block + blocks[block.block_id] = block + else: + self._unexpected_blocks_type(blocks) + + def pop(self, key: BlockHashWithGroupId, + block_id: int) -> Optional[KVCacheBlock]: + """ + Checks if block_hash exists and pop block_id from the cache + """ + blocks = self._cache.pop(key, None) + if blocks is None: + # block_hash not found in the cache + return None + # TODO(Jialin): If key is found, block_id should always present + # in blocks. We currently keep the original behaviour for safety. + # + # Will add block_id == blocks.block_id assertion and + # use del blocks[block_id] instead as followup. + if isinstance(blocks, KVCacheBlock): + if blocks.block_id == block_id: + return blocks + # If the single block ID doesn't match, we should put the + # block back (it should happen rarely) + self._cache[key] = blocks + return None + if isinstance(blocks, dict): + # Try to pop block_id from the block dict, and if dict still + # contain blocks, put back to the cache. 
+ block = blocks.pop(block_id, None) + if len(blocks) > 0: + self._cache[key] = blocks + return block + self._unexpected_blocks_type(blocks) + return None + + def __len__(self) -> int: + return len(self._cache) + + def _unexpected_blocks_type(self, blocks: Any) -> None: + raise AssertionError(f"Invalid KV cache block type {type(blocks)}") + + class BlockPool: """BlockPool that manages KVCacheBlocks. It provides methods to allocate, free and cache the kv cache blocks. The @@ -51,17 +147,9 @@ class BlockPool: # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[ - int, KVCacheBlock]] = defaultdict(dict) + # Cache for block lookup + self.cached_block_hash_to_block: BlockHashToBlockMap = \ + BlockHashToBlockMap() # To represent a placeholder block with block_id=0. 
# The ref_cnt of null_block is not maintained, needs special care to @@ -90,12 +178,11 @@ class BlockPool: for group_id in kv_cache_group_ids: block_hash_with_group_id = make_block_hash_with_group_id( block_hash, group_id) - cached_blocks_one_group = self.cached_block_hash_to_block.get( + block = self.cached_block_hash_to_block.get_one_block( block_hash_with_group_id) - if not cached_blocks_one_group: + if not block: return None - first_block = next(iter(cached_blocks_one_group.values())) - cached_blocks.append(first_block) + cached_blocks.append(block) return cached_blocks def cache_full_blocks( @@ -140,8 +227,8 @@ class BlockPool: block_hash_with_group_id = make_block_hash_with_group_id( block_hash, kv_cache_group_id) blk.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block[block_hash_with_group_id][ - blk.block_id] = blk + self.cached_block_hash_to_block.insert(block_hash_with_group_id, + blk) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) @@ -211,15 +298,14 @@ class BlockPool: if block_hash is None: # The block doesn't have hash, eviction is not needed return False - blocks_by_id = self.cached_block_hash_to_block.get(block_hash) - if blocks_by_id is None: - # block_hash not found in cached_block_hash_to_block, + + if self.cached_block_hash_to_block.pop(block_hash, + block.block_id) is None: + # block not found in cached_block_hash_to_block, # eviction is not needed return False + block.reset_hash() - blocks_by_id.pop(block.block_id, None) - if len(blocks_by_id) == 0: - del self.cached_block_hash_to_block[block_hash] if self.enable_kv_cache_events: # FIXME (Chen): Not sure whether we should return `hash_value` @@ -283,7 +369,7 @@ class BlockPool: return False # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) + self.cached_block_hash_to_block = BlockHashToBlockMap() # Remove all hashes from all blocks. 
for block in self.blocks: diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index f0076b2d81dbf..52264e41e7a18 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -411,6 +411,19 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_inter_token_latency = make_per_engine( histogram_inter_token_latency, engine_indexes, model_name) + histogram_request_time_per_output_token = self._histogram_cls( + name="vllm:request_time_per_output_token_seconds", + documentation= + "Histogram of time_per_output_token_seconds per request.", + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ], + labelnames=labelnames) + self.histogram_request_time_per_output_token = make_per_engine( + histogram_request_time_per_output_token, engine_indexes, + model_name) + request_latency_buckets = [ 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 @@ -583,6 +596,8 @@ class PrometheusStatLogger(StatLoggerBase): finished_request.num_prompt_tokens) self.histogram_num_generation_tokens_request[engine_idx].observe( finished_request.num_generation_tokens) + self.histogram_request_time_per_output_token[engine_idx].observe( + finished_request.mean_time_per_output_token) if finished_request.max_tokens_param: self.histogram_max_tokens_request[engine_idx].observe( finished_request.max_tokens_param) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 0eff557336bc0..296c39e8cdb5c 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -86,6 +86,7 @@ class FinishedRequestStats: prefill_time: float = 0.0 inference_time: float = 0.0 decode_time: float = 0.0 + mean_time_per_output_token: float = 0.0 class IterationStats: @@ -177,6 +178,12 @@ class IterationStats: # Any preemptions during prefill or decode are included inference_time = req_stats.last_token_ts - req_stats.scheduled_ts + # 
Do not count the token generated by the prefill phase + mean_time_per_output_token = (decode_time / + (req_stats.num_generation_tokens - 1) + if req_stats.num_generation_tokens - + 1 > 0 else 0) + finished_req = \ FinishedRequestStats(finish_reason=finish_reason, e2e_latency=e2e_latency, @@ -186,7 +193,8 @@ class IterationStats: queued_time=queued_time, prefill_time=prefill_time, inference_time=inference_time, - decode_time=decode_time) + decode_time=decode_time, + mean_time_per_output_token=mean_time_per_output_token) self.finished_requests.append(finished_req) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a9e0a38fe3417..5cae7df704701 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,7 +3,7 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import Optional, Protocol +from typing import Optional import numpy as np import torch @@ -37,17 +37,6 @@ logger = init_logger(__name__) PADDING_SLOT_ID = -1 -class EagleAttentionMetadata(Protocol): - # Required attributes - num_actual_tokens: int - max_query_len: int - query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor - slot_mapping: torch.Tensor - - class EagleProposer: def __init__( @@ -120,7 +109,7 @@ class EagleProposer: with_numpy=True) # Determine allowed attention backends once during initialization. - self.allowed_attn_types: tuple[type, ...] + self.allowed_attn_types: Optional[tuple] = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend @@ -129,9 +118,6 @@ class EagleProposer: AiterFlashAttentionMetadata) rocm_types.append(AiterFlashAttentionMetadata) self.allowed_attn_types = tuple(rocm_types) - else: - self.allowed_attn_types = (FlashAttentionMetadata, - TreeAttentionMetadata) # Parse the speculative token tree. 
spec_token_tree = self.speculative_config.speculative_token_tree @@ -266,7 +252,8 @@ class EagleProposer: draft_token_ids = logits.argmax(dim=-1) - if not isinstance(attn_metadata, self.allowed_attn_types): + if self.allowed_attn_types is not None and \ + not isinstance(attn_metadata, self.allowed_attn_types): raise ValueError( f"Unsupported attention metadata type for speculative " "decoding with num_speculative_tokens > 1: " diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 55b4792fe010d..a853e6540719e 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -108,7 +108,9 @@ class XgrammarBackend(StructuredOutputBackend): end=s["end"], ) for s in s_tag["structures"] ] - ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"]) + structural_tag = xgr.StructuralTag.from_legacy_structural_tag( + tags, s_tag["triggers"]) + ctx = self.compiler.compile_structural_tag(structural_tag) else: logger.error( "Validation should have already occurred. Please file an issue." 
@@ -318,6 +320,8 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: end=s["end"], ) for s in s_tag["structures"] ] - xgr.Grammar.from_structural_tag(tags, s_tag["triggers"]) + structural_tag = xgr.StructuralTag.from_legacy_structural_tag( + tags, s_tag["triggers"]) + xgr.Grammar.from_structural_tag(structural_tag) except Exception as e: raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index fd84b4a111f58..ec4417290f611 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -375,8 +375,22 @@ def report_usage_stats( }) +_PROFILER_FUNC = None + + def record_function_or_nullcontext(name: str) -> AbstractContextManager: + global _PROFILER_FUNC + + # fast path assume it is set + if _PROFILER_FUNC is not None: + return _PROFILER_FUNC(name) + + func = contextlib.nullcontext if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: - return record_function(name) - else: - return contextlib.nullcontext() + func = record_function + elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: + import nvtx + func = nvtx.annotate + + _PROFILER_FUNC = func + return func(name) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f785824958147..ee339e22cea90 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2807,7 +2807,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, - seq_len=self.max_num_tokens, + seq_len=self.max_model_len, mm_counts={modality: 1}, cache=self.mm_budget.cache, ) @@ -2828,7 +2828,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _dummy_run( self, num_tokens: int, - cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, force_attention: bool = False, uniform_decode: bool = False, allow_microbatching: bool = True, @@ 
-2844,6 +2844,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: num_tokens: Number of tokens to run the dummy forward pass. cudagraph_runtime_mode: used to control the behavior. + - if not set, the cudagraph mode will be determined using + the self.cudagraph_dispatcher. - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. - CUDAGraphMode.FULL: Full cudagraph, attention metadata is @@ -2857,7 +2859,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ - assert cudagraph_runtime_mode in { + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } @@ -2899,10 +2901,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): elif uniform_decode: assert not create_mixed_batch num_reqs = cdiv(num_tokens, max_query_len) - assert num_reqs <= max_num_reqs, \ - f"Do not capture num_reqs {num_reqs} > max_num_reqs " \ - f"{max_num_reqs} for uniform batch. 
Num tokens: " \ - f"{num_tokens}, max_query_len: {max_query_len}" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len @@ -3043,18 +3041,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - if cudagraph_runtime_mode == CUDAGraphMode.NONE: - batch_descriptor = None - else: - # filter out the valid batch descriptor - _cg_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens, - uniform_decode=uniform_decode)) - # sanity check - assert cudagraph_runtime_mode == _cg_mode, ( + + # filter out the valid batch descriptor + _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) + if cudagraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for cudagraph capture + assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \ + cudagraph_runtime_mode == _cg_mode, ( f"Cudagraph runtime mode mismatch at dummy_run. 
" f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + else: + cudagraph_runtime_mode = _cg_mode if ubatch_slices is not None: num_tokens = num_tokens // 2 diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0b40ac6a7d629..7c7db72198406 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -170,6 +170,20 @@ class Worker(WorkerBase): current_platform.set_device(self.device) current_platform.check_if_supports_dtype(self.model_config.dtype) + + # Initialize the distributed environment BEFORE taking + # memory snapshot + # This ensures NCCL buffers are allocated before we measure + # available memory + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend) + + # Set random seed. + set_random_seed(self.model_config.seed) + + # Now take memory snapshot after NCCL is initialized gc.collect() torch.cuda.empty_cache() @@ -191,13 +205,6 @@ class Worker(WorkerBase): else: raise RuntimeError( f"Not support device type: {self.device_config.device}") - # Initialize the distributed environment. - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) - # Set random seed. 
- set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner: GPUModelRunner = GPUModelRunner( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4cbf991a14c11..4a2adb1e6510d 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1795,7 +1795,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, - seq_len=self.max_num_tokens, + seq_len=self.max_model_len, mm_counts={modality: 1}, cache=self.mm_budget.cache, ) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index fc72b954df9cf..d4f0a65f2a164 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -3,7 +3,7 @@ """A TPU worker class.""" import os -from typing import Any, Optional +from typing import Any, Callable, Optional, TypeVar import torch import torch.distributed @@ -31,6 +31,8 @@ from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) +_R = TypeVar("_R") + if not USE_TPU_COMMONS: logger.info("tpu_commons not found, using vLLM's TPUWorker.") import torch_xla.core.xla_model as xm @@ -333,6 +335,10 @@ class TPUWorker: def shutdown(self) -> None: self.model_runner.ensure_kv_transfer_shutdown() + def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: + """Apply a function on the model inside this worker.""" + return fn(self.get_model()) + if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker