diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
index 7479c43977d78..2267718f75ca5 100644
--- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -73,12 +73,11 @@ function cpu_tests() {
pytest -x -s -v \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs"
- # Note: disable it until supports V1
- # Run AWQ test
- # docker exec cpu-test-"$NUMA_NODE" bash -c "
- # set -e
- # pytest -x -s -v \
- # tests/quantization/test_ipex_quant.py"
+ # Run AWQ/GPTQ test
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -x -s -v \
+ tests/quantization/test_cpu_wna16.py"
# Run multi-lora tests
docker exec cpu-test-"$NUMA_NODE" bash -c "
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index e232000511c31..2471b509a9fff 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -1068,7 +1068,7 @@ steps:
# this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/compile/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
- # Wrap with quotes to escape yaml
+ # Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
- label: Blackwell Fusion E2E Tests # 30 min
@@ -1095,10 +1095,11 @@ steps:
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
-- label: Blackwell GPT-OSS Eval
+- label: ROCm GPT-OSS Eval
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
- gpu: b200
+ agent_pool: mi325_1
+ mirror_hardwares: [amdproduction]
optional: true # run on nightlies
source_file_dependencies:
- tests/evals/gpt_oss
@@ -1107,7 +1108,7 @@ steps:
- vllm/v1/attention/backends/flashinfer.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
+ - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: Blackwell Quantized MoE Test
timeout_in_minutes: 60
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 52539728215bb..4ac76aba67b9c 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -478,10 +478,11 @@ steps:
- vllm/
- tests/compile
commands:
+ # fp8 kv scales not supported on sm89, tested on Blackwell instead
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
- - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
+ - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
- label: Cudagraph test
timeout_in_minutes: 20
@@ -925,7 +926,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
-- label: Blackwell Fusion Tests # 30 min
+- label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
@@ -946,7 +947,9 @@ steps:
- pytest -v -s tests/compile/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml
- - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
+ - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
+ # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
+ - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40
@@ -969,8 +972,6 @@ steps:
- nvidia-smi
# Run all e2e fusion tests
- pytest -v -s tests/compile/test_fusions_e2e.py
- # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
@@ -1266,7 +1267,8 @@ steps:
- pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
+ - "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
+ - pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py
diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml
index 8d40aa587bf00..42b05ecd5ac06 100644
--- a/.github/workflows/macos-smoke-test.yml
+++ b/.github/workflows/macos-smoke-test.yml
@@ -1,6 +1,9 @@
name: macOS Apple Silicon Smoke Test
on:
+ push:
+ branches:
+ - main
workflow_dispatch: # Manual trigger
jobs:
@@ -19,13 +22,15 @@ jobs:
pyproject.toml
python-version: '3.12'
- - name: Install dependencies
+ - name: Create virtual environment
run: |
- uv pip install -r requirements/cpu-build.txt
- uv pip install -r requirements/cpu.txt
+ uv venv
+ echo "$GITHUB_WORKSPACE/.venv/bin" >> "$GITHUB_PATH"
- - name: Build vLLM
- run: uv pip install -v -e .
+ - name: Install dependencies and build vLLM
+ run: |
+ uv pip install -r requirements/cpu.txt --index-strategy unsafe-best-match
+ uv pip install -e .
env:
CMAKE_BUILD_PARALLEL_LEVEL: 4
diff --git a/.gitignore b/.gitignore
index 50070d7898fe6..7cda86478664f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,9 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
+# OpenAI triton kernels copied from source
+vllm/third_party/triton_kernels/*
+
# triton jit
.triton
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dcc44be87e557..ae8e6175443f3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -512,9 +512,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
@@ -619,9 +619,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
- cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
@@ -695,7 +695,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
@@ -741,9 +741,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
@@ -861,7 +861,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
# Hadacore kernels
- cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
if(HADACORE_ARCHS)
set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
set_gencode_flags_for_srcs(
@@ -1030,6 +1030,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
endif()
+# For CUDA and HIP builds also build the triton_kernels external package.
+if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
+ include(cmake/external_projects/triton_kernels.cmake)
+endif()
+
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake)
diff --git a/benchmarks/benchmark_batch_invariance.py b/benchmarks/benchmark_batch_invariance.py
new file mode 100755
index 0000000000000..b5c16c42de467
--- /dev/null
+++ b/benchmarks/benchmark_batch_invariance.py
@@ -0,0 +1,380 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode.
+
+This benchmark runs the same workload twice:
+1. With VLLM_BATCH_INVARIANT=0 (baseline)
+2. With VLLM_BATCH_INVARIANT=1 (batch invariant mode)
+
+It then reports the timing and throughput metrics of both runs for comparison.
+
+Environment variables:
+ VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B")
+ VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek)
+ VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128)
+ VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5)
+ VLLM_BENCH_MIN_PROMPT: Min prompt length in words (default: 1024)
+ VLLM_BENCH_MAX_PROMPT: Max prompt length in words (default: 2048)
+ VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128)
+ VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0)
+ VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4)
+ VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120)
+ VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN)
+
+Example usage:
+ # Benchmark qwen3 (default)
+ python benchmarks/benchmark_batch_invariance.py
+
+ # Benchmark deepseek with 8 GPUs
+ VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8 \\
+ python benchmarks/benchmark_batch_invariance.py
+
+ # Quick test with fewer trials
+ VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32 \\
+ python benchmarks/benchmark_batch_invariance.py
+"""
+
+import contextlib
+import os
+import random
+import time
+
+from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
+
+
+def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
+ """Generate a random prompt for benchmarking."""
+ prompt_templates = [
+ "Question: What is the capital of France?\nAnswer: The capital of France is",
+ "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
+ "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
+ "Once upon a time in a distant galaxy, there lived",
+ "The old man walked slowly down the street, remembering",
+ "In the year 2157, humanity finally discovered",
+ "To implement a binary search tree in Python, first we need to",
+ "The algorithm works by iterating through the array and",
+ "Here's how to optimize database queries using indexing:",
+ "The Renaissance was a period in European history that",
+ "Climate change is caused by several factors including",
+ "The human brain contains approximately 86 billion neurons which",
+ "I've been thinking about getting a new laptop because",
+ "Yesterday I went to the store and bought",
+ "My favorite thing about summer is definitely",
+ ]
+
+ base_prompt = random.choice(prompt_templates)
+
+ if max_words < min_words:
+ max_words = min_words
+ target_words = random.randint(min_words, max_words)
+
+ if target_words > 50:
+ padding_text = (
+ " This is an interesting topic that deserves more explanation. "
+ * (target_words // 50)
+ )
+ base_prompt = base_prompt + padding_text
+
+ return base_prompt
+
+
+def run_benchmark_with_batch_invariant(
+ model: str,
+ tp_size: int,
+ max_batch_size: int,
+ num_trials: int,
+ min_prompt: int,
+ max_prompt: int,
+ max_tokens: int,
+ temperature: float,
+ gpu_mem_util: float,
+ max_model_len: int,
+ backend: str,
+ batch_invariant: bool,
+ seed: int = 12345,
+) -> dict:
+ """
+ Run the benchmark with the specified configuration.
+
+ Returns a dict with timing and throughput metrics.
+ """
+ random.seed(seed)
+
+ # Set environment variables
+ os.environ["VLLM_ATTENTION_BACKEND"] = backend
+ if batch_invariant:
+ os.environ["VLLM_BATCH_INVARIANT"] = "1"
+ else:
+ os.environ["VLLM_BATCH_INVARIANT"] = "0"
+
+ print(f"\n{'=' * 80}")
+ print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}")
+ print(f" Model: {model}")
+ print(f" TP Size: {tp_size}")
+ print(f" Backend: {backend}")
+ print(f" Max Batch Size: {max_batch_size}")
+ print(f" Trials: {num_trials}")
+ print(f" Max Tokens: {max_tokens}")
+ print(f"{'=' * 80}\n")
+
+ sampling = SamplingParams(
+ temperature=temperature,
+ top_p=0.95,
+ max_tokens=max_tokens,
+ seed=20240919,
+ )
+
+ needle_prompt = "There once was a "
+
+ llm = None
+ try:
+ # Create LLM engine
+ start_init = time.perf_counter()
+ llm = LLM(
+ model=model,
+ max_num_seqs=max_batch_size,
+ gpu_memory_utilization=gpu_mem_util,
+ max_model_len=max_model_len,
+ dtype="bfloat16",
+ tensor_parallel_size=tp_size,
+ enable_prefix_caching=False,
+ )
+ init_time = time.perf_counter() - start_init
+ print(f"Engine initialization time: {init_time:.2f}s\n")
+
+ # Generate baseline
+ print("Generating baseline (warmup)...")
+ baseline_out = llm.generate([needle_prompt], sampling)
+ assert len(baseline_out) == 1
+ baseline_text = baseline_out[0].outputs[0].text
+ print(f"Baseline output: '{baseline_text[:50]}...'\n")
+
+ # Run trials and measure timing
+ trial_times: list[float] = []
+ total_tokens = 0
+ total_prompts = 0
+
+ for trial in range(num_trials):
+ # Create a batch
+ prompts: list[str] = []
+ batch_size = random.randint(max_batch_size // 2, max_batch_size)
+ needle_pos = random.randint(0, batch_size - 1)
+ for i in range(batch_size):
+ if i == needle_pos:
+ prompts.append(needle_prompt)
+ else:
+ prompts.append(_random_prompt(min_prompt, max_prompt))
+
+ # Measure time for this trial
+ start_time = time.perf_counter()
+ outputs = llm.generate(prompts, sampling)
+ trial_time = time.perf_counter() - start_time
+
+ trial_times.append(trial_time)
+ total_prompts += len(prompts)
+
+ # Count tokens
+ for output in outputs:
+ if output.outputs:
+ total_tokens += len(output.outputs[0].token_ids)
+
+ print(
+ f"Trial {trial + 1}/{num_trials}: "
+ f"batch_size={batch_size}, "
+ f"time={trial_time:.2f}s"
+ )
+
+ # Verify needle output still matches
+ needle_output = outputs[needle_pos]
+ assert needle_output.prompt == needle_prompt
+
+ # Compute statistics
+ avg_time = sum(trial_times) / len(trial_times)
+ min_time = min(trial_times)
+ max_time = max(trial_times)
+ throughput = total_tokens / sum(trial_times)
+ prompts_per_sec = total_prompts / sum(trial_times)
+
+ print(f"\n{'=' * 80}")
+ print("RESULTS:")
+ print(f" Average time per trial: {avg_time:.2f}s")
+ print(f" Min time: {min_time:.2f}s")
+ print(f" Max time: {max_time:.2f}s")
+ print(f" Total tokens generated: {total_tokens}")
+ print(f" Total prompts processed: {total_prompts}")
+ print(f" Throughput: {throughput:.2f} tokens/s")
+ print(f" Prompts/s: {prompts_per_sec:.2f}")
+ print(f"{'=' * 80}\n")
+
+ return {
+ "init_time": init_time,
+ "avg_time": avg_time,
+ "min_time": min_time,
+ "max_time": max_time,
+ "total_tokens": total_tokens,
+ "total_prompts": total_prompts,
+ "throughput": throughput,
+ "prompts_per_sec": prompts_per_sec,
+ "trial_times": trial_times,
+ }
+
+ finally:
+ # Cleanup
+ if llm is not None:
+ with contextlib.suppress(Exception):
+ llm.shutdown()
+
+
+def main():
+ # Check platform support
+ if not (current_platform.is_cuda() and current_platform.has_device_capability(90)):
+ print("ERROR: Requires CUDA and >= Hopper (SM90)")
+ print(f"Current platform: {current_platform.device_type}")
+ if current_platform.is_cuda():
+ print(f"Device capability: {current_platform.get_device_capability()}")
+ return 1
+
+ # Read configuration from environment
+ model = os.getenv("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B")
+ tp_size = int(os.getenv("VLLM_BENCH_TP_SIZE", "1"))
+ max_batch_size = int(os.getenv("VLLM_BENCH_BATCH_SIZE", "128"))
+ num_trials = int(os.getenv("VLLM_BENCH_NUM_TRIALS", "5"))
+ min_prompt = int(os.getenv("VLLM_BENCH_MIN_PROMPT", "1024"))
+ max_prompt = int(os.getenv("VLLM_BENCH_MAX_PROMPT", "2048"))
+ max_tokens = int(os.getenv("VLLM_BENCH_MAX_TOKENS", "128"))
+ temperature = float(os.getenv("VLLM_BENCH_TEMPERATURE", "0.0"))
+ gpu_mem_util = float(os.getenv("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4"))
+ max_model_len = int(os.getenv("VLLM_BENCH_MAX_MODEL_LEN", "5120"))
+ backend = os.getenv("VLLM_BENCH_BACKEND", "FLASH_ATTN")
+
+ print("\n" + "=" * 80)
+ print("VLLM BATCH INVARIANCE BENCHMARK")
+ print("=" * 80)
+ print("\nConfiguration:")
+ print(f" Model: {model}")
+ print(f" Tensor Parallel Size: {tp_size}")
+ print(f" Attention Backend: {backend}")
+ print(f" Max Batch Size: {max_batch_size}")
+ print(f" Number of Trials: {num_trials}")
+ print(f" Prompt Length Range: {min_prompt}-{max_prompt} words")
+ print(f" Max Tokens to Generate: {max_tokens}")
+ print(f" Temperature: {temperature}")
+ print(f" GPU Memory Utilization: {gpu_mem_util}")
+ print(f" Max Model Length: {max_model_len}")
+ print("=" * 80)
+
+ # Run benchmark WITHOUT batch invariance (baseline)
+ print("\n" + "=" * 80)
+ print("PHASE 1: Running WITHOUT batch invariance (baseline)")
+ print("=" * 80)
+ baseline_results = run_benchmark_with_batch_invariant(
+ model=model,
+ tp_size=tp_size,
+ max_batch_size=max_batch_size,
+ num_trials=num_trials,
+ min_prompt=min_prompt,
+ max_prompt=max_prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ gpu_mem_util=gpu_mem_util,
+ max_model_len=max_model_len,
+ backend=backend,
+ batch_invariant=False,
+ )
+
+ # Run benchmark WITH batch invariance
+ print("\n" + "=" * 80)
+ print("PHASE 2: Running WITH batch invariance")
+ print("=" * 80)
+ batch_inv_results = run_benchmark_with_batch_invariant(
+ model=model,
+ tp_size=tp_size,
+ max_batch_size=max_batch_size,
+ num_trials=num_trials,
+ min_prompt=min_prompt,
+ max_prompt=max_prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ gpu_mem_util=gpu_mem_util,
+ max_model_len=max_model_len,
+ backend=backend,
+ batch_invariant=True,
+ )
+
+ # Compare results
+ print("\n" + "=" * 80)
+ print("COMPARISON: Batch Invariance vs Baseline")
+ print("=" * 80)
+
+ init_overhead_pct = (
+ (batch_inv_results["init_time"] - baseline_results["init_time"])
+ / baseline_results["init_time"]
+ * 100
+ )
+ time_overhead_pct = (
+ (batch_inv_results["avg_time"] - baseline_results["avg_time"])
+ / baseline_results["avg_time"]
+ * 100
+ )
+ throughput_change_pct = (
+ (batch_inv_results["throughput"] - baseline_results["throughput"])
+ / baseline_results["throughput"]
+ * 100
+ )
+
+ print("\nInitialization Time:")
+ print(f" Baseline: {baseline_results['init_time']:.2f}s")
+ print(f" Batch Invariant: {batch_inv_results['init_time']:.2f}s")
+ print(f" Overhead: {init_overhead_pct:+.2f}%")
+
+ print("\nAverage Trial Time:")
+ print(f" Baseline: {baseline_results['avg_time']:.2f}s")
+ print(f" Batch Invariant: {batch_inv_results['avg_time']:.2f}s")
+ print(f" Overhead: {time_overhead_pct:+.2f}%")
+
+ print("\nThroughput (tokens/s):")
+ print(f" Baseline: {baseline_results['throughput']:.2f}")
+ print(f" Batch Invariant: {batch_inv_results['throughput']:.2f}")
+ print(f" Change: {throughput_change_pct:+.2f}%")
+
+ print("\nPrompts/s:")
+ print(f" Baseline: {baseline_results['prompts_per_sec']:.2f}")
+ print(f" Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}")
+
+ print("\n" + "=" * 80)
+ print("SUMMARY")
+ print("=" * 80)
+ if time_overhead_pct > 0:
+ print(
+ f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% "
+ "overhead"
+ )
+ else:
+ print(
+ f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% "
+ "faster (unexpected!)"
+ )
+
+ if abs(throughput_change_pct) < 1.0:
+ print("Throughput difference is negligible (< 1%)")
+ elif throughput_change_pct < 0:
+ print(
+ f"Throughput decreased by {-throughput_change_pct:.1f}% "
+ "with batch invariance"
+ )
+ else:
+ print(
+ f"Throughput increased by {throughput_change_pct:.1f}% "
+ "with batch invariance (unexpected!)"
+ )
+
+ print("=" * 80 + "\n")
+
+ return 0
+
+
+if __name__ == "__main__":
+ exit(main())
diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
index 027f67ad4db69..e07d6c776bc00 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
@@ -255,8 +255,8 @@ def bench_run(
torch.cuda.synchronize()
# Timing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies = []
for _ in range(num_iters):
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index c99951aa27826..a1af0b8aec3d0 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -185,8 +185,8 @@ def benchmark_config(
graph.replay()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index efa5a7386027e..b8913a217c608 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -105,8 +105,8 @@ def benchmark_permute(
graph.replay()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
@@ -241,8 +241,8 @@ def benchmark_unpermute(
graph.replay()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
diff --git a/benchmarks/kernels/benchmark_per_token_group_quant.py b/benchmarks/kernels/benchmark_per_token_group_quant.py
index bdc1eb733084e..eba4d510258b6 100644
--- a/benchmarks/kernels/benchmark_per_token_group_quant.py
+++ b/benchmarks/kernels/benchmark_per_token_group_quant.py
@@ -30,8 +30,8 @@ def _time_cuda(
fn()
torch.cuda.synchronize()
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
start.record()
for _ in range(bench_iters):
diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
index a5887aafd30d6..de01ff197eab7 100644
--- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
+++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
@@ -253,8 +253,8 @@ def benchmark(
)
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
# Benchmark
latencies: list[float] = []
diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
index 29ce18234dfa0..1d0d6fbb9a470 100644
--- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -127,8 +127,8 @@ def benchmark_decode(
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()
diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
index 2a25d03748112..84bde723abf7f 100644
--- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
@@ -139,8 +139,8 @@ def benchmark_prefill(
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()
diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
index ab54f81985bc2..b52500c8c5217 100644
--- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py
+++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
@@ -183,8 +183,8 @@ def benchmark_config(
run()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md
index f5b5c6c97d484..b0be1e3a69a66 100644
--- a/benchmarks/multi_turn/README.md
+++ b/benchmarks/multi_turn/README.md
@@ -55,6 +55,10 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75
----------------------------------------------------------------------------------------------------
```
+If you run with `--warmup-step`, the summary will also include `warmup_runtime_sec`
+and `total_runtime_incl_warmup_sec` (while `runtime_sec` continues to reflect the
+benchmark-only runtime so the reported throughput stays comparable).
+
### JSON configuration file for synthetic conversations generation
The input flag `--input-file` is used to determine the input conversations for the benchmark.
diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
index ae9e9753441aa..e23f6b923f1b9 100644
--- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py
+++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
@@ -561,8 +561,11 @@ async def client_main(
f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501
)
- random.seed(args.seed)
- np.random.seed(args.seed)
+ # Set unique seed per client (each client runs in its own process)
+ # Add 1 to ensure no client uses the same seed as the main process
+ client_seed = args.seed + client_id + 1
+ random.seed(client_seed)
+ np.random.seed(client_seed)
# Active conversations
active_convs: ConversationsMap = {}
@@ -1073,6 +1076,7 @@ def process_statistics(
verbose: bool,
gen_conv_args: GenConvArgs | None = None,
excel_output: bool = False,
+ warmup_runtime_sec: float | None = None,
) -> None:
if len(client_metrics) == 0:
logger.info("No samples to process")
@@ -1166,8 +1170,13 @@ def process_statistics(
# Convert milliseconds to seconds
runtime_sec = runtime_sec / 1000.0
requests_per_sec = float(len(df)) / runtime_sec
-
- params = {"runtime_sec": runtime_sec, "requests_per_sec": requests_per_sec}
+ params = {
+ "runtime_sec": runtime_sec,
+ "requests_per_sec": requests_per_sec,
+ }
+ if warmup_runtime_sec is not None:
+ params["warmup_runtime_sec"] = warmup_runtime_sec
+ params["total_runtime_incl_warmup_sec"] = runtime_sec + warmup_runtime_sec
# Generate a summary of relevant metrics (and drop irrelevant data)
df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose()
@@ -1490,6 +1499,7 @@ async def main() -> None:
f"Invalid --warmup-percentage={args.warmup_percentage}"
) from None
+ # Set global seeds for main process
random.seed(args.seed)
np.random.seed(args.seed)
@@ -1548,6 +1558,8 @@ async def main() -> None:
url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
)
+ warmup_runtime_sec: float | None = None
+
# Warm-up step
if args.warmup_step:
# Only send a single user prompt from every conversation.
@@ -1562,26 +1574,56 @@ async def main() -> None:
# all clients should finish their work before exiting
warmup_bench_args = bench_args._replace(early_stop=False)
- logger.info(f"{Color.PURPLE}Warmup start{Color.RESET}")
+ logger.info("%sWarmup start%s", Color.PURPLE, Color.RESET)
+ warmup_start_ns = time.perf_counter_ns()
conversations, _ = await main_mp(
warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations
)
- logger.info(f"{Color.PURPLE}Warmup done{Color.RESET}")
+ warmup_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - warmup_start_ns)
+ logger.info(
+ "%sWarmup runtime: %.3f sec (%.3f ms)%s",
+ Color.PURPLE,
+ warmup_runtime_sec,
+ warmup_runtime_sec * 1000,
+ Color.RESET,
+ )
+ logger.info("%sWarmup done%s", Color.PURPLE, Color.RESET)
# Run the benchmark
- start_time = time.perf_counter_ns()
+ benchmark_start_ns = time.perf_counter_ns()
client_convs, client_metrics = await main_mp(
client_args, req_args, bench_args, tokenizer, conversations
)
- total_runtime_ms = nanosec_to_millisec(time.perf_counter_ns() - start_time)
+ benchmark_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - benchmark_start_ns)
# Calculate requests per second
- total_runtime_sec = total_runtime_ms / 1000.0
- rps = len(client_metrics) / total_runtime_sec
+ requests_per_sec = len(client_metrics) / benchmark_runtime_sec
+ benchmark_runtime_ms = benchmark_runtime_sec * 1000.0
logger.info(
- f"{Color.GREEN}All clients finished, total runtime: {total_runtime_sec:.3f} sec"
- f" ({total_runtime_ms:.3f} ms), requests per second: {rps:.3f}{Color.RESET}"
+ "%sAll clients finished, benchmark runtime: %.3f sec (%.3f ms), "
+ "requests per second: %.3f%s",
+ Color.GREEN,
+ benchmark_runtime_sec,
+ benchmark_runtime_ms,
+ requests_per_sec,
+ Color.RESET,
)
+ if warmup_runtime_sec is not None:
+ total_runtime_sec = benchmark_runtime_sec + warmup_runtime_sec
+ logger.info(
+ "%sWarmup runtime: %.3f sec (%.3f ms)%s",
+ Color.GREEN,
+ warmup_runtime_sec,
+ warmup_runtime_sec * 1000,
+ Color.RESET,
+ )
+ logger.info(
+ "%sTotal runtime (including warmup): %.3f sec (%.3f ms)%s",
+ Color.GREEN,
+ total_runtime_sec,
+ total_runtime_sec * 1000,
+ Color.RESET,
+ )
# Benchmark parameters
params = {
@@ -1606,6 +1648,7 @@ async def main() -> None:
verbose=args.verbose,
gen_conv_args=gen_conv_args,
excel_output=args.excel_output,
+ warmup_runtime_sec=warmup_runtime_sec,
)
if args.output_file is not None:
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index aa84125818d10..fbbb03c5ed465 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -375,6 +375,7 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
+ "csrc/cpu/cpu_wna16.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC
diff --git a/cmake/external_projects/triton_kernels.cmake b/cmake/external_projects/triton_kernels.cmake
new file mode 100644
index 0000000000000..d35ad123dd9de
--- /dev/null
+++ b/cmake/external_projects/triton_kernels.cmake
@@ -0,0 +1,53 @@
+# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
+
+set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0")
+
+# Set TRITON_KERNELS_SRC_DIR for local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to
+# be directly set to the triton_kernels python directory.
+if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
+ message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}")
+ FetchContent_Declare(
+ triton_kernels
+ SOURCE_DIR $ENV{TRITON_KERNELS_SRC_DIR}
+ )
+
+else()
+ set(TRITON_GIT "https://github.com/triton-lang/triton.git")
+ message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}")
+ FetchContent_Declare(
+ triton_kernels
+ # TODO (varun) : Fetch just the triton_kernels directory from Triton
+ GIT_REPOSITORY https://github.com/triton-lang/triton.git
+ GIT_TAG ${DEFAULT_TRITON_KERNELS_TAG}
+ GIT_PROGRESS TRUE
+ SOURCE_SUBDIR python/triton_kernels/triton_kernels
+ )
+endif()
+
+# Fetch content
+FetchContent_MakeAvailable(triton_kernels)
+
+if (NOT triton_kernels_SOURCE_DIR)
+ message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR")
+endif()
+
+if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
+ set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/")
+else()
+ set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/")
+endif()
+
+message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}")
+
+add_custom_target(triton_kernels)
+
+# Ensure the vllm/third_party directory exists before installation
+install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")")
+
+## Copy .py files to install directory.
+install(DIRECTORY
+ ${TRITON_KERNELS_PYTHON_DIR}
+ DESTINATION
+ vllm/third_party/triton_kernels/
+ COMPONENT triton_kernels
+ FILES_MATCHING PATTERN "*.py")
diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp
index 5de8a114b2b55..294b4f714a769 100644
--- a/csrc/cpu/cpu_attn_impl.hpp
+++ b/csrc/cpu/cpu_attn_impl.hpp
@@ -1,7 +1,6 @@
#ifndef CPU_ATTN_HPP
#define CPU_ATTN_HPP
-#include
#include
#include
@@ -12,6 +11,7 @@
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
+#include "utils.hpp"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16 };
@@ -754,7 +754,7 @@ class AttentionScheduler {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
- return 128 * 1024 >> 1; // use 50% of 128KB
+ return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
TORCH_CHECK_NE(l2_cache_size, -1);
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index 7ddf028e6e131..6f51277f78440 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec {
explicit FP16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
+ explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
+
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec {
explicit BF16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
+ explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
+
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec {
explicit FP32Vec16(__m512 data) : reg(data) {}
+  // unpack 4-bit values
+ explicit FP32Vec16(int64_t value, const FP32Vec16& lut) {
+ int64_t mask_0 = 0x0F0F0F0F0F0F0F0F;
+ int64_t mask_1 = 0xF0F0F0F0F0F0F0F0;
+ int64_t value_0 = value & mask_0;
+ int64_t value_1 = value & mask_1;
+ __m128i vec_0 = _mm_movpi64_epi64((__m64)value_0);
+ __m128i vec_1 = _mm_movpi64_epi64((__m64)value_1);
+ vec_0 = _mm_cvtepu8_epi16(vec_0);
+ vec_1 = _mm_cvtepu8_epi16(vec_1);
+ vec_1 = _mm_slli_epi16(vec_1, 4);
+ __m128i vec = _mm_or_si128(vec_0, vec_1);
+ __m512i vec_i32 = _mm512_cvtepu8_epi32(vec);
+ reg = _mm512_permutexvar_ps(vec_i32, lut.reg);
+ }
+
explicit FP32Vec16(const FP32Vec4& data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
@@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec {
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
- template
- float reduce_sub_sum(int idx) {
- static_assert(VEC_ELEM_NUM % group_size == 0);
- constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
- __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
- return _mm512_mask_reduce_add_ps(mask, reg);
- }
-
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
void save(float* ptr, const int elem_num) const {
@@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
_mm512_stream_ps((float*)ptr, vec.reg);
}
+
+static void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1,
+ void* ptr) {
+ __m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
+ __m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
+ vec_1 = _mm512_slli_epi32(vec_1, 16);
+ vec_0 = _mm512_or_si512(vec_0, vec_1);
+ _mm512_storeu_epi32(ptr, vec_0);
+}
+
+static void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1,
+ void* ptr) {
+ __m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
+ __m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
+ vec_1 = _mm512_slli_epi32(vec_1, 16);
+ vec_0 = _mm512_or_si512(vec_0, vec_1);
+ _mm512_storeu_epi32(ptr, vec_0);
+}
+
#endif
inline void mem_barrier() { _mm_mfence(); }
diff --git a/csrc/cpu/cpu_wna16.cpp b/csrc/cpu/cpu_wna16.cpp
new file mode 100644
index 0000000000000..816d195506e52
--- /dev/null
+++ b/csrc/cpu/cpu_wna16.cpp
@@ -0,0 +1,402 @@
+#include "cpu_types.hpp"
+#include "scratchpad_manager.h"
+#include "utils.hpp"
+
+#ifdef CPU_CAPABILITY_AMXBF16
+ #include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
+#endif
+#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
+
+#define VLLM_DISPATCH_CASE_16B_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
+
+#define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__))
+
+template
+void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
+ int32_t stride) {
+ std::stringstream ss;
+ ss << std::fixed << std::setprecision(5) << name << ": [\n";
+ auto* curr_logits_buffer = ptr;
+ for (int32_t m = 0; m < row; ++m) {
+ for (int32_t n = 0; n < col; ++n) {
+ ss << curr_logits_buffer[n] << ", ";
+ }
+ ss << "\n";
+ curr_logits_buffer += stride;
+ }
+ ss << "]\n";
+ std::printf("%s", ss.str().c_str());
+}
+
+namespace {
+using cpu_utils::ISA;
+using cpu_utils::VecTypeTrait;
+
+template
+class Dequantizer4b {
+ public:
+ constexpr static int32_t pack_num = 32 / 4;
+ using scalar_vec_t = typename VecTypeTrait::vec_t;
+
+ public:
+ static void dequant(int32_t* __restrict__ q_weight,
+ scalar_t* __restrict__ weight,
+ scalar_t* __restrict__ scales,
+ int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
+ const int64_t scales_stride, const int64_t zeros_stride,
+ const int32_t k_size, const int32_t group_size) {
+ vec_op::FP32Vec16 lut;
+ if constexpr (has_zp) {
+ // AWQ
+ alignas(64) static const float LUT[16] = {
+ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
+ 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
+ lut = vec_op::FP32Vec16(LUT);
+ } else {
+ // GPTQ
+ alignas(64) static const float LUT[16] = {
+ -8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f,
+ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+ lut = vec_op::FP32Vec16(LUT);
+ }
+
+    // each 64-bit elem contains 16 output channels
+ int64_t* __restrict__ curr_q_weight = reinterpret_cast(q_weight);
+ int64_t* __restrict__ curr_zeros = reinterpret_cast(zeros);
+ scalar_t* __restrict__ curr_weight = weight;
+ scalar_t* __restrict__ curr_scale = scales;
+ vec_op::FP32Vec16 scale_0;
+ vec_op::FP32Vec16 scale_1;
+ vec_op::FP32Vec16 zero_0;
+ vec_op::FP32Vec16 zero_1;
+ int32_t group_counter = 0;
+ for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) {
+ int64_t qwb_0 = *curr_q_weight;
+ int64_t qwb_1 = *(curr_q_weight + 1);
+ vec_op::FP32Vec16 wb_0(qwb_0, lut);
+ vec_op::FP32Vec16 wb_1(qwb_1, lut);
+
+ if constexpr (!use_desc_act) {
+ if (group_counter == 0) {
+ scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale));
+ scale_1 = vec_op::FP32Vec16(scale_0);
+ curr_scale += scales_stride;
+
+ if constexpr (has_zp) {
+ zero_0 = vec_op::FP32Vec16(*curr_zeros, lut);
+ zero_1 = vec_op::FP32Vec16(zero_0);
+ curr_zeros += zeros_stride / 2;
+ }
+ }
+ } else {
+ int32_t g_idx_0 = g_idx[k_idx];
+ int32_t g_idx_1 = g_idx[k_idx + 1];
+ scale_0 = vec_op::FP32Vec16(
+ scalar_vec_t(curr_scale + g_idx_0 * scales_stride));
+ scale_1 = vec_op::FP32Vec16(
+ scalar_vec_t(curr_scale + g_idx_1 * scales_stride));
+ if constexpr (has_zp) {
+ zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2),
+ lut);
+ zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2),
+ lut);
+ }
+ }
+
+ if constexpr (has_zp) {
+ wb_0 = wb_0 - zero_0;
+ wb_1 = wb_1 - zero_1;
+ }
+
+ wb_0 = wb_0 * scale_0;
+ wb_1 = wb_1 * scale_1;
+
+ scalar_vec_t output_vec_0(wb_0);
+ scalar_vec_t output_vec_1(wb_1);
+
+      // AMX needs to interleave K elements to pack as 32 bits
+ if constexpr (isa == ISA::AMX) {
+ vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight);
+ } else {
+ output_vec_0.save(curr_weight);
+ output_vec_1.save(curr_weight + 16);
+ }
+
+ // update
+ curr_q_weight += 2;
+ curr_weight += 32;
+ if constexpr (!use_desc_act) {
+ group_counter += 2;
+ if (group_counter == group_size) {
+ group_counter = 0;
+ }
+ }
+ }
+ }
+};
+}; // namespace
+
+template
+void cpu_gemm_wna16_impl(
+ scalar_t* __restrict__ input, int32_t* __restrict__ q_weight,
+ scalar_t* __restrict__ output, scalar_t* __restrict__ scales,
+ int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
+ scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size,
+ const int32_t k_size, const int64_t input_stride,
+ const int64_t output_stride, const int64_t scales_group_stride,
+ const int64_t zeros_group_stride, const int32_t group_num,
+ const int32_t group_size, const int64_t pack_factor) {
+ constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
+ constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
+ constexpr int32_t n_block_size = 16;
+ static_assert(gemm_n_tile_size % n_block_size == 0);
+ const int32_t thread_num = omp_get_max_threads();
+
+ // a simple schedule policy, just to hold more B tiles in L2 and make sure
+ // each thread has tasks
+ const int32_t n_partition_size = [&]() {
+ const int64_t cache_size = cpu_utils::get_l2_size();
+ int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
+ int64_t ps_thread_limit = n_size / thread_num;
+ ps_cache_limit =
+ std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size,
+ (int64_t)gemm_n_tile_size);
+ ps_thread_limit =
+ std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size,
+ (int64_t)gemm_n_tile_size);
+ return std::min(ps_cache_limit, ps_thread_limit);
+ }();
+ const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size;
+
+ // get buffer size
+ const int64_t b_buffer_size =
+ (((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64);
+ const int64_t c_buffer_size =
+ (((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64);
+ const int64_t b_buffer_offset = 0;
+ const int64_t c_buffer_offset = b_buffer_size;
+ const int64_t buffer_size = b_buffer_size + c_buffer_size;
+ DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
+ thread_num);
+
+ alignas(64) cpu_utils::Counter counter;
+ cpu_utils::Counter* counter_ptr = &counter;
+
+#pragma omp parallel for schedule(static, 1)
+ for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
+ scalar_t* __restrict__ b_buffer = nullptr;
+ float* __restrict__ c_buffer = nullptr;
+ {
+ uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
+ ->get_data() +
+ thread_id * buffer_size;
+ b_buffer = reinterpret_cast(buffer_ptr + b_buffer_offset);
+ c_buffer = reinterpret_cast(buffer_ptr + c_buffer_offset);
+ }
+
+ const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size;
+ const int64_t b_buffer_block_stride = n_block_size * k_size;
+ const int32_t zeros_block_stride = n_block_size / pack_factor;
+
+ gemm_t gemm;
+
+ for (;;) {
+ int32_t task_id = counter_ptr->acquire_counter();
+
+ if (task_id >= task_num) {
+ break;
+ }
+
+ const int32_t n_start_idx = task_id * n_partition_size;
+ const int32_t n_block_start_idx = n_start_idx / n_block_size;
+ const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx);
+ const int32_t n_block_num = n_num / n_block_size;
+ // std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n",
+ // thread_id, task_id, n_start_idx, n_num);
+
+ // dequant weight
+ {
+ int32_t* __restrict__ curr_q_weight =
+ q_weight + n_block_start_idx * q_weight_block_stride;
+ scalar_t* __restrict__ curr_b_buffer = b_buffer;
+ scalar_t* __restrict__ curr_scales = scales + n_start_idx;
+ int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor;
+ for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) {
+ dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales,
+ curr_zeros, g_idx, scales_group_stride,
+ zeros_group_stride, k_size, group_size);
+
+ // if (block_idx == 0 && n_start_idx == 0) {
+ // print_logits("depacked weight", curr_b_buffer, k_size,
+ // n_block_size, n_block_size);
+ // }
+
+ // update
+ curr_q_weight += q_weight_block_stride;
+ curr_b_buffer += b_buffer_block_stride;
+ curr_scales += n_block_size;
+ curr_zeros += zeros_block_stride;
+ }
+ }
+
+ // compute loop
+ {
+ const int32_t n_tile_num = n_num / gemm_n_tile_size;
+ scalar_t* __restrict__ curr_input = input;
+ scalar_t* __restrict__ init_bias = bias;
+ if (bias != nullptr) {
+ init_bias += n_start_idx;
+ }
+ scalar_t* __restrict__ init_output = output + n_start_idx;
+ for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) {
+ const int32_t curr_m_size =
+ std::min(gemm_m_tile_size, m_size - m_idx);
+ scalar_t* __restrict__ curr_b_buffer = b_buffer;
+ scalar_t* __restrict__ curr_bias = init_bias;
+ scalar_t* __restrict__ curr_output = init_output;
+ for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) {
+ gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size,
+ input_stride, b_buffer_block_stride, gemm_n_tile_size,
+ false);
+
+ if (bias != nullptr) {
+ cpu_micro_gemm::bias_epilogue(
+ c_buffer, curr_output, curr_bias, curr_m_size,
+ gemm_n_tile_size, output_stride);
+ curr_bias += gemm_n_tile_size;
+ } else {
+ cpu_micro_gemm::default_epilogue(
+ c_buffer, curr_output, curr_m_size, gemm_n_tile_size,
+ output_stride);
+ }
+
+ curr_b_buffer +=
+ b_buffer_block_stride * (gemm_n_tile_size / n_block_size);
+ curr_output += gemm_n_tile_size;
+ }
+ curr_input += gemm_m_tile_size * input_stride;
+ init_output += gemm_m_tile_size * output_stride;
+ }
+ }
+ }
+ }
+}
+
+void cpu_gemm_wna16(
+ const torch::Tensor& input, // [M, K]
+ const torch::Tensor&
+ q_weight, // [N / 16, K * 16 / pack_factor], packed as int32
+ torch::Tensor& output, // [M, N]
+ const torch::Tensor& scales, // [group_num, N]
+ const std::optional&
+ zeros, // [group_num, N / pack_factor], packed as int32
+ const std::optional& g_idx, // [K]
+ const std::optional& bias, // [N]
+ const int64_t pack_factor, const std::string& isa_hint) {
+ using cpu_utils::ISA;
+ TORCH_CHECK_EQ(pack_factor, 8); // only supports 4bits
+ const int32_t a_m_size = input.size(0);
+ const int32_t a_k_size = input.size(1);
+ const int64_t a_m_stride = input.stride(0);
+ const int32_t b_n_size = q_weight.size(0) * 16;
+ TORCH_CHECK_EQ(a_k_size % 32, 0);
+ TORCH_CHECK_EQ(b_n_size % 32, 0);
+ const int32_t group_num = scales.size(0);
+ const int32_t group_size = a_k_size / group_num;
+ TORCH_CHECK_EQ(group_size % 2, 0);
+ const int64_t scales_group_stride = scales.stride(0);
+ const int64_t output_m_stride = output.stride(0);
+
+ bool has_zp = zeros.has_value();
+ bool use_desc_act = g_idx.has_value();
+ TORCH_CHECK(!(has_zp && use_desc_act));
+
+ ISA isa = [&]() {
+ if (isa_hint == "amx") {
+ return ISA::AMX;
+ } else if (isa_hint == "vec") {
+ return ISA::VEC;
+ } else {
+ TORCH_CHECK(false, "unsupported isa hint: " + isa_hint);
+ }
+ }();
+
+ int32_t* zeros_ptr = has_zp ? zeros->data_ptr() : nullptr;
+ const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0;
+ int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr() : nullptr;
+
+ VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() {
+ if (isa == ISA::AMX) {
+ using gemm_t = cpu_micro_gemm::MicroGemm;
+ if (has_zp) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ if (use_desc_act) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ } else {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ } else if (isa == ISA::VEC) {
+ using gemm_t = cpu_micro_gemm::MicroGemm;
+ if (has_zp) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ if (use_desc_act) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ } else {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ }
+ });
+}
diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp
index 02a8072ccf306..cfb6e78cba9a1 100644
--- a/csrc/cpu/dnnl_helper.cpp
+++ b/csrc/cpu/dnnl_helper.cpp
@@ -396,9 +396,9 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
: DNNLMatMulPrimitiveHandler(
static_cast(args), args.ab_type),
m_size_cache_(nullptr) {
- assert(ab_type_ == dnnl::memory::data_type::f32 ||
- ab_type_ == dnnl::memory::data_type::bf16 ||
- ab_type_ == dnnl::memory::data_type::f16);
+ assert(b_type_ == dnnl::memory::data_type::f32 ||
+ b_type_ == dnnl::memory::data_type::bf16 ||
+ b_type_ == dnnl::memory::data_type::f16);
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
new file mode 100644
index 0000000000000..87a019773a895
--- /dev/null
+++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
@@ -0,0 +1,245 @@
+#ifndef CPU_MICRO_GEMM_AMX_HPP
+#define CPU_MICRO_GEMM_AMX_HPP
+#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
+
+namespace cpu_micro_gemm {
+namespace {
+// AMX specific
+constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
+constexpr static int64_t AMX_TILE_ROW_NUM = 16;
+constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
+
+typedef struct __tile_config {
+ uint8_t palette_id = 1;
+ uint8_t start_row = 0;
+ uint8_t reserved_0[14] = {0};
+ uint16_t colsb[16] = {0};
+ uint8_t rows[16] = {0};
+} __tilecfg;
+
+// 2-2-4 pattern, for 16 < m <= 32
+// TILE 0, 1: load A matrix, row num should be 16, m - 16
+// TILE 2, 3: load B matrix, row num should be 16
+// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
+// - 16
+template
+class TileGemm224 {
+ public:
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm224");
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm224");
+ }
+};
+
+template <>
+class TileGemm224 {
+ public:
+ using scalar_t = c10::BFloat16;
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
+ c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
+ c10::BFloat16* __restrict__ a_tile_1 = a_ptr + lda * AMX_TILE_ROW_NUM;
+ const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
+
+ // B is always packed as 16 output channels block
+ c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
+ c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
+ const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
+
+ float* __restrict__ c_tile_4 = c_ptr;
+ float* __restrict__ c_tile_5 =
+ c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
+ float* __restrict__ c_tile_6 = c_ptr + AMX_TILE_ROW_NUM * ldc;
+ float* __restrict__ c_tile_7 =
+ c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
+ const int32_t c_tile_stride = ldc * sizeof(float);
+
+ if (accum_c) {
+ _tile_loadd(4, c_tile_4, c_tile_stride);
+ _tile_loadd(5, c_tile_5, c_tile_stride);
+ _tile_loadd(6, c_tile_6, c_tile_stride);
+ _tile_loadd(7, c_tile_7, c_tile_stride);
+ } else {
+ _tile_zero(4);
+ _tile_zero(5);
+ _tile_zero(6);
+ _tile_zero(7);
+ }
+
+ for (int32_t k = 0; k < k_times; ++k) {
+ _tile_loadd(0, a_tile_0, a_tile_stride);
+ _tile_stream_loadd(2, b_tile_2, b_tile_stride);
+ _tile_dpbf16ps(4, 0, 2);
+ _tile_stream_loadd(3, b_tile_3, b_tile_stride);
+ _tile_dpbf16ps(5, 0, 3);
+ _tile_loadd(1, a_tile_1, a_tile_stride);
+ _tile_dpbf16ps(6, 1, 2);
+ _tile_dpbf16ps(7, 1, 3);
+
+ // update ptrs
+ a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ }
+
+ _tile_stored(4, c_tile_4, c_tile_stride);
+ _tile_stored(5, c_tile_5, c_tile_stride);
+ _tile_stored(6, c_tile_6, c_tile_stride);
+ _tile_stored(7, c_tile_7, c_tile_stride);
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ const int32_t m_0 = AMX_TILE_ROW_NUM;
+ const int32_t m_1 = m - AMX_TILE_ROW_NUM;
+ config.rows[0] = m_0;
+ config.rows[1] = m_1;
+ config.rows[2] = AMX_TILE_ROW_NUM;
+ config.rows[3] = AMX_TILE_ROW_NUM;
+ config.rows[4] = m_0;
+ config.rows[5] = m_0;
+ config.rows[6] = m_1;
+ config.rows[7] = m_1;
+ _tile_loadconfig(&config);
+ }
+};
+
+// 1-2-2 pattern, for 0 < m <= 16
+// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
+// m, m
+// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
+// num should be 16
+// TILE 6, 7, (6, 7): store results C matrix, row num should be
+// m
+template
+class TileGemm122 {
+ public:
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm122");
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm122");
+ }
+};
+
+template <>
+class TileGemm122 {
+ public:
+ using scalar_t = c10::BFloat16;
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
+ c10::BFloat16* __restrict__ a_tile_1 =
+ a_ptr + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
+
+ c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
+ c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
+ c10::BFloat16* __restrict__ b_tile_4 =
+ b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ c10::BFloat16* __restrict__ b_tile_5 =
+ b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ int64_t b_stride = AMX_TILE_ROW_BYTES;
+
+ float* __restrict__ c_tile_6 = c_ptr;
+ float* __restrict__ c_tile_7 = c_ptr + AMX_TILE_ROW_BYTES / sizeof(float);
+ int64_t c_stride = ldc * sizeof(float);
+
+ const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
+ const int32_t k_group_times = k_times / 2;
+ const bool has_tail = (k_times % 2 == 1);
+
+ if (accum_c) {
+ _tile_loadd(6, c_tile_6, c_stride);
+ _tile_loadd(7, c_tile_7, c_stride);
+ } else {
+ _tile_zero(6);
+ _tile_zero(7);
+ }
+
+ for (int32_t k = 0; k < k_group_times; ++k) {
+ _tile_loadd(0, a_tile_0, a_tile_stride);
+ _tile_stream_loadd(2, b_tile_2, b_stride);
+ _tile_dpbf16ps(6, 0, 2);
+ _tile_stream_loadd(3, b_tile_3, b_stride);
+ _tile_dpbf16ps(7, 0, 3);
+ _tile_loadd(1, a_tile_1, a_tile_stride);
+ _tile_stream_loadd(4, b_tile_4, b_stride);
+ _tile_dpbf16ps(6, 1, 4);
+ _tile_stream_loadd(5, b_tile_5, b_stride);
+ _tile_dpbf16ps(7, 1, 5);
+
+ // update ptrs
+ a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ }
+
+ if (has_tail) {
+ _tile_loadd(0, a_tile_0, a_tile_stride);
+ _tile_stream_loadd(2, b_tile_2, b_stride);
+ _tile_dpbf16ps(6, 0, 2);
+ _tile_stream_loadd(3, b_tile_3, b_stride);
+ _tile_dpbf16ps(7, 0, 3);
+ }
+
+ _tile_stored(6, c_tile_6, c_stride);
+ _tile_stored(7, c_tile_7, c_stride);
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ config.rows[0] = m;
+ config.rows[1] = m;
+ config.rows[2] = AMX_TILE_ROW_NUM;
+ config.rows[3] = AMX_TILE_ROW_NUM;
+ config.rows[4] = AMX_TILE_ROW_NUM;
+ config.rows[5] = AMX_TILE_ROW_NUM;
+ config.rows[6] = m;
+ config.rows[7] = m;
+ _tile_loadconfig(&config);
+ }
+};
+} // namespace
+
+// Gemm kernel uses AMX, requires B matrix to be packed
+template
+class MicroGemm {
+ public:
+ static constexpr int32_t MaxMSize = 32;
+ static constexpr int32_t NSize = 32;
+
+ public:
+ MicroGemm() : curr_m_(-1) {
+ vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; });
+ }
+
+ void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ if (m > AMX_TILE_ROW_NUM) {
+ if (m != curr_m_) {
+ curr_m_ = m;
+ TileGemm224::init_tile_config(m, amx_tile_config_);
+ }
+ TileGemm224::gemm(CPU_MICRO_GEMM_PARAMS);
+ } else {
+ if (m != curr_m_) {
+ curr_m_ = m;
+ TileGemm122::init_tile_config(m, amx_tile_config_);
+ }
+ TileGemm122::gemm(CPU_MICRO_GEMM_PARAMS);
+ }
+ }
+
+ private:
+ alignas(64) __tilecfg amx_tile_config_;
+ int32_t curr_m_;
+};
+
+} // namespace cpu_micro_gemm
+
+#endif
diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
new file mode 100644
index 0000000000000..784da55a420e5
--- /dev/null
+++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
@@ -0,0 +1,91 @@
+#ifndef CPU_MICRO_GEMM_IMPL_HPP
+#define CPU_MICRO_GEMM_IMPL_HPP
+#include "cpu/utils.hpp"
+#include "cpu/cpu_types.hpp"
+
+namespace cpu_micro_gemm {
+#define DEFINE_CPU_MICRO_GEMM_PARAMS \
+ scalar_t *__restrict__ a_ptr, scalar_t *__restrict__ b_ptr, \
+ float *__restrict__ c_ptr, const int32_t m, const int32_t k, \
+ const int64_t lda, const int64_t b_n_group_stride, const int64_t ldc, \
+ const bool accum_c
+
+#define CPU_MICRO_GEMM_PARAMS \
+ a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
+
+template
+class MicroGemm {
+ public:
+ static constexpr int32_t MaxMSize = 16;
+ static constexpr int32_t NSize = 16;
+
+ public:
+ void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TORCH_CHECK(false, "Unimplemented MicroGemm.");
+ }
+};
+
+template
+FORCE_INLINE void default_epilogue(float* __restrict__ c_ptr,
+ scalar_t* __restrict__ d_ptr,
+ const int32_t m, const int64_t ldc,
+ const int64_t ldd) {
+ using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t;
+ static_assert(n_size % 16 == 0);
+
+ float* __restrict__ curr_c = c_ptr;
+ scalar_t* __restrict__ curr_d = d_ptr;
+ for (int32_t i = 0; i < m; ++i) {
+ float* __restrict__ curr_c_iter = curr_c;
+ scalar_t* __restrict__ curr_d_iter = curr_d;
+ vec_op::unroll_loop([&](int32_t n_g_idx) {
+ vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
+ scalar_vec_t c_vec(c_vec_fp32);
+ c_vec.save(curr_d_iter);
+ curr_c_iter += 16;
+ curr_d_iter += 16;
+ });
+ curr_c += ldc;
+ curr_d += ldd;
+ }
+}
+
+template
+FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
+ scalar_t* __restrict__ d_ptr,
+ scalar_t* __restrict__ bias_ptr,
+ const int32_t m, const int64_t ldc,
+ const int64_t ldd) {
+ using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t;
+ static_assert(n_size % 16 == 0);
+ constexpr int32_t n_group_num = n_size / 16;
+ static_assert(n_group_num <= 16);
+
+ vec_op::FP32Vec16 bias_vecs[n_group_num];
+ scalar_t* __restrict__ curr_bias = bias_ptr;
+ vec_op::unroll_loop([&](int32_t i) {
+ scalar_vec_t vec(curr_bias);
+ bias_vecs[i] = vec_op::FP32Vec16(vec);
+ curr_bias += 16;
+ });
+
+ float* __restrict__ curr_c = c_ptr;
+ scalar_t* __restrict__ curr_d = d_ptr;
+ for (int32_t i = 0; i < m; ++i) {
+ float* __restrict__ curr_c_iter = curr_c;
+ scalar_t* __restrict__ curr_d_iter = curr_d;
+ vec_op::unroll_loop([&](int32_t n_g_idx) {
+ vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
+ c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
+ scalar_vec_t c_vec(c_vec_fp32);
+ c_vec.save(curr_d_iter);
+ curr_c_iter += 16;
+ curr_d_iter += 16;
+ });
+ curr_c += ldc;
+ curr_d += ldd;
+ }
+}
+} // namespace cpu_micro_gemm
+
+#endif
diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
new file mode 100644
index 0000000000000..3985c2f2e5fe4
--- /dev/null
+++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
@@ -0,0 +1,115 @@
+#ifndef CPU_MICRO_GEMM_VEC_HPP
+#define CPU_MICRO_GEMM_VEC_HPP
+#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
+
+namespace cpu_micro_gemm {
+namespace {
+// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
+template
+class TileGemm82 {
+ public:
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ switch (m) {
+ case 1:
+ gemm_micro<1>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 2:
+ gemm_micro<2>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 3:
+ gemm_micro<3>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 4:
+ gemm_micro<4>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 5:
+ gemm_micro<5>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 6:
+ gemm_micro<6>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 7:
+ gemm_micro<7>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 8:
+ gemm_micro<8>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ }
+ }
+
+  template
+  static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+    // Fix: `0 < M <= 8` parses as `(0 < M) <= 8` (bool vs 8) and is always
+    // true; spell out both bounds so the assert actually constrains M.
+    static_assert(0 < M && M <= 8);
+    using load_vec_t = typename cpu_utils::VecTypeTrait::vec_t;
+
+    scalar_t* __restrict__ curr_b_0 = b_ptr;
+    scalar_t* __restrict__ curr_b_1 = b_ptr + b_n_group_stride;
+    float* __restrict__ curr_c_0 = c_ptr;
+    float* __restrict__ curr_c_1 = c_ptr + 16;
+
+    vec_op::FP32Vec16 c_regs[M * 2];
+    if (accum_c) {
+      float* __restrict__ curr_m_c_0 = curr_c_0;
+      float* __restrict__ curr_m_c_1 = curr_c_1;
+      vec_op::unroll_loop([&](int32_t i) {
+        c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
+        c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
+
+        // update
+        curr_m_c_0 += ldc;
+        curr_m_c_1 += ldc;
+      });
+    }
+
+    scalar_t* __restrict__ curr_a = a_ptr;
+    for (int32_t k_idx = 0; k_idx < k; ++k_idx) {
+      load_vec_t b_0_reg(curr_b_0);
+      vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
+      load_vec_t b_1_reg(curr_b_1);
+      vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
+
+      scalar_t* __restrict__ curr_m_a = curr_a;
+      vec_op::unroll_loop([&](int32_t i) {
+        scalar_t v = *curr_m_a;
+        load_vec_t a_reg_original(v);
+        vec_op::FP32Vec16 a_reg(a_reg_original);
+        c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
+        c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
+
+        // update
+        curr_m_a += lda;
+      });
+
+      // update
+      curr_a += 1;
+      curr_b_0 += 16;
+      curr_b_1 += 16;
+    }
+
+    vec_op::unroll_loop([&](int32_t i) {
+      c_regs[i * 2].save(curr_c_0);
+      c_regs[i * 2 + 1].save(curr_c_1);
+
+      // update
+      curr_c_0 += ldc;
+      curr_c_1 += ldc;
+    });
+  }
+};
+} // namespace
+
+// Gemm kernel uses vector instructions, requires B matrix to be packed
+template
+class MicroGemm {
+ public:
+ static constexpr int32_t MaxMSize = 8;
+ static constexpr int32_t NSize = 32;
+
+ public:
+ void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TileGemm82::gemm(CPU_MICRO_GEMM_PARAMS);
+ }
+};
+} // namespace cpu_micro_gemm
+
+#endif
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 5e2aa70692566..b07d20bab7dd9 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -100,6 +100,16 @@ void cpu_attention_with_kv_cache(
const torch::Tensor& scheduler_metadata,
const std::optional& s_aux);
+// Note: just for avoiding import errors
+void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
+
+void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
+ torch::Tensor& output, const torch::Tensor& scales,
+ const std::optional& zeros,
+ const std::optional& g_idx,
+ const std::optional& bias,
+ const int64_t pack_factor, const std::string& isa_hint);
+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@@ -275,6 +285,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
&cpu_attention_with_kv_cache);
+
+ // placeholders
+ ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
+ ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
+ ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
+
+ // WNA16
+#if defined(__AVX512F__)
+ ops.def(
+ "cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
+ "Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
+ "pack_factor, str isa_hint) -> ()");
+ ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
+#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
diff --git a/csrc/cpu/utils.hpp b/csrc/cpu/utils.hpp
new file mode 100644
index 0000000000000..d8399c56f6af8
--- /dev/null
+++ b/csrc/cpu/utils.hpp
@@ -0,0 +1,55 @@
+#ifndef UTILS_HPP
+#define UTILS_HPP
+
+#include
+#include
+#include
+#include
+
+#include "cpu_types.hpp"
+
+namespace cpu_utils {
+enum class ISA { AMX, VEC };
+
+template
+struct VecTypeTrait {
+ using vec_t = void;
+};
+
+template <>
+struct VecTypeTrait {
+ using vec_t = vec_op::FP32Vec16;
+};
+
+template <>
+struct VecTypeTrait {
+ using vec_t = vec_op::BF16Vec16;
+};
+
+template <>
+struct VecTypeTrait {
+ using vec_t = vec_op::FP16Vec16;
+};
+
+struct Counter {
+ std::atomic counter;
+ char _padding[56];
+
+ Counter() : counter(0) {}
+
+ void reset_counter() { counter.store(0); }
+
+ int64_t acquire_counter() { return counter++; }
+};
+
+inline int64_t get_l2_size() {
+ static int64_t size = []() {
+ long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
+ assert(l2_cache_size != -1);
+ return l2_cache_size >> 1; // use 50% of L2 cache
+ }();
+ return size;
+}
+} // namespace cpu_utils
+
+#endif
diff --git a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu
index 5369d409f9b21..aff11326d78e9 100644
--- a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu
+++ b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu
@@ -802,7 +802,7 @@ torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) {
});
if (numel % 256 != 0) {
- out = out.index({torch::indexing::Slice(0, numel / had_size)});
+ out = out.narrow(0, 0, numel / had_size);
}
if (inplace && out.data_ptr() != x.data_ptr()) {
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 4e6ef8f5ca13c..5d5b82c4fa5af 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -14,6 +14,7 @@ RUN apt clean && apt-get update -y && \
libxext6 \
libgl1 \
lsb-release \
+ libaio-dev \
numactl \
wget \
vim \
@@ -68,8 +69,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
RUN python3 -m pip install -e tests/vllm_test_utils
# install nixl from source code
+ENV NIXL_VERSION=0.7.0
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
-ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages/.nixl.mesonpy.libs/plugins/"
RUN --mount=type=cache,target=/root/.cache/pip \
pip uninstall oneccl oneccl-devel -y
diff --git a/docs/README.md b/docs/README.md
index 0608794e7e650..0c279c19f96ca 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -30,8 +30,8 @@ Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at
Where to get started with vLLM depends on the type of user. If you are looking to:
- Run open-source models on vLLM, we recommend starting with the [Quickstart Guide](./getting_started/quickstart.md)
-- Build applications with vLLM, we recommend starting with the [User Guide](./usage)
-- Build vLLM, we recommend starting with [Developer Guide](./contributing)
+- Build applications with vLLM, we recommend starting with the [User Guide](./usage/README.md)
+- Build vLLM, we recommend starting with [Developer Guide](./contributing/README.md)
For information about the development of vLLM, see:
diff --git a/docs/cli/bench/latency.md b/docs/cli/bench/latency.md
index 21ab13e63781a..ea7ea7321ffcd 100644
--- a/docs/cli/bench/latency.md
+++ b/docs/cli/bench/latency.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/bench_latency.md"
+--8<-- "docs/argparse/bench_latency.inc.md"
diff --git a/docs/cli/bench/serve.md b/docs/cli/bench/serve.md
index f7c415c6becb5..f7dc8036cc262 100644
--- a/docs/cli/bench/serve.md
+++ b/docs/cli/bench/serve.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/bench_serve.md"
+--8<-- "docs/argparse/bench_serve.inc.md"
diff --git a/docs/cli/bench/sweep/plot.md b/docs/cli/bench/sweep/plot.md
index f29bffb64655c..a101330e093cc 100644
--- a/docs/cli/bench/sweep/plot.md
+++ b/docs/cli/bench/sweep/plot.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/bench_sweep_plot.md"
+--8<-- "docs/argparse/bench_sweep_plot.inc.md"
diff --git a/docs/cli/bench/sweep/serve.md b/docs/cli/bench/sweep/serve.md
index 5b5f91a951ed0..f0468f06fc287 100644
--- a/docs/cli/bench/sweep/serve.md
+++ b/docs/cli/bench/sweep/serve.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/bench_sweep_serve.md"
+--8<-- "docs/argparse/bench_sweep_serve.inc.md"
diff --git a/docs/cli/bench/sweep/serve_sla.md b/docs/cli/bench/sweep/serve_sla.md
index 5f8ab6005e50b..5642ec67eb007 100644
--- a/docs/cli/bench/sweep/serve_sla.md
+++ b/docs/cli/bench/sweep/serve_sla.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/bench_sweep_serve_sla.md"
+--8<-- "docs/argparse/bench_sweep_serve_sla.inc.md"
diff --git a/docs/cli/bench/throughput.md b/docs/cli/bench/throughput.md
index e4ff5ce43c9ce..e7f618fb4d147 100644
--- a/docs/cli/bench/throughput.md
+++ b/docs/cli/bench/throughput.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/bench_throughput.md"
+--8<-- "docs/argparse/bench_throughput.inc.md"
diff --git a/docs/cli/chat.md b/docs/cli/chat.md
index b006cb8de60d0..0246bd431b101 100644
--- a/docs/cli/chat.md
+++ b/docs/cli/chat.md
@@ -1,5 +1,5 @@
# vllm chat
-## Options
+## Arguments
---8<-- "docs/argparse/chat.md"
+--8<-- "docs/argparse/chat.inc.md"
diff --git a/docs/cli/complete.md b/docs/cli/complete.md
index 400359acf4fb8..eb2ffdaabac25 100644
--- a/docs/cli/complete.md
+++ b/docs/cli/complete.md
@@ -1,5 +1,5 @@
# vllm complete
-## Options
+## Arguments
---8<-- "docs/argparse/complete.md"
+--8<-- "docs/argparse/complete.inc.md"
diff --git a/docs/cli/run-batch.md b/docs/cli/run-batch.md
index f7d401b8dad2b..758fbda283978 100644
--- a/docs/cli/run-batch.md
+++ b/docs/cli/run-batch.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/run-batch.md"
+--8<-- "docs/argparse/run-batch.inc.md"
diff --git a/docs/cli/serve.md b/docs/cli/serve.md
index 2c8f9d320f5df..35652fec587b3 100644
--- a/docs/cli/serve.md
+++ b/docs/cli/serve.md
@@ -4,6 +4,6 @@
--8<-- "docs/cli/json_tip.inc.md"
-## Options
+## Arguments
---8<-- "docs/argparse/serve.md"
+--8<-- "docs/argparse/serve.inc.md"
diff --git a/docs/configuration/serve_args.md b/docs/configuration/serve_args.md
index c1cc5577bc7ab..baaf21f01f066 100644
--- a/docs/configuration/serve_args.md
+++ b/docs/configuration/serve_args.md
@@ -5,7 +5,7 @@ The `vllm serve` command is used to launch the OpenAI-compatible server.
## CLI Arguments
The `vllm serve` command is used to launch the OpenAI-compatible server.
-To see the available options, take a look at the [CLI Reference](../cli/README.md#options)!
+To see the available options, take a look at the [CLI Reference](../cli/README.md)!
## Configuration file
diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md
index ec0dfc4199d17..c9bc9cfe28a35 100644
--- a/docs/contributing/benchmarks.md
+++ b/docs/contributing/benchmarks.md
@@ -983,7 +983,7 @@ each document has close to 512 tokens.
Please note that the `/v1/rerank` is also supported by embedding models. So if you're running
with an embedding model, also set `--no_reranker`. Because in this case the query is
-treated as a individual prompt by the server, here we send `random_batch_size - 1` documents
+treated as an individual prompt by the server, here we send `random_batch_size - 1` documents
to account for the extra prompt which is the query. The token accounting to report the
throughput numbers correctly is also adjusted.
diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index 7941b1f49ee8b..7634cc0859edf 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -224,6 +224,6 @@ snakeviz expensive_function.prof
Leverage VLLM_GC_DEBUG environment variable to debug GC costs.
-- VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elpased times
+- VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elapsed times
- VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger to log top 5
collected objects for each gc.collect
diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md
index aac7b76eea265..66bf3b27d1f52 100644
--- a/docs/design/cuda_graphs.md
+++ b/docs/design/cuda_graphs.md
@@ -128,7 +128,7 @@ A [CUDAGraphWrapper][vllm.compilation.cuda_graph.CUDAGraphWrapper] instance wrap
3. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, the wrapper will perform CUDA Graphs capture (if key does not exist, create
a new entry and cache it) or replay (if key exists in the cache).
-The above steps are based on the assumption that the CUDA Graphs wrapper would directly trust what’s in the forward context (controlled by the dispatcher). This lets us simplify and cenralize the logic, reducing the complexity as well as the risk of mismatched state between the wrappers and the dispatcher. It also allows reusing the wrapper class for both `FULL` and `PIECEWISE` runtime modes. See the implementation [here](https://github.com/vllm-project/vllm/blob/f751e50b7a2aae3110d83ed0d88202fc91b3e78a/vllm/compilation/cuda_graph.py#L106).
+The above steps are based on the assumption that the CUDA Graphs wrapper would directly trust what’s in the forward context (controlled by the dispatcher). This lets us simplify and centralize the logic, reducing the complexity as well as the risk of mismatched state between the wrappers and the dispatcher. It also allows reusing the wrapper class for both `FULL` and `PIECEWISE` runtime modes. See the implementation [here](https://github.com/vllm-project/vllm/blob/f751e50b7a2aae3110d83ed0d88202fc91b3e78a/vllm/compilation/cuda_graph.py#L106).
#### Nested Wrapper design
diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md
index 2f4b17f191a5d..91ab4deae71df 100644
--- a/docs/design/io_processor_plugins.md
+++ b/docs/design/io_processor_plugins.md
@@ -1,6 +1,6 @@
# IO Processor Plugins
-IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output.
+IO Processor plugins are a feature that allows pre- and post-processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output.
When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint.
diff --git a/docs/design/logits_processors.md b/docs/design/logits_processors.md
index acf7fc245462c..8eadeb386fcf2 100644
--- a/docs/design/logits_processors.md
+++ b/docs/design/logits_processors.md
@@ -411,7 +411,7 @@ Logits processor `update_state()` implementations should assume the following mo
* **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot. Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous
- * **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
+ * **Shrink the batch:** a side effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch
@@ -548,7 +548,7 @@ Built-in logits processors are always loaded when the vLLM engine starts. See th
Review these logits processor implementations for guidance on writing built-in logits processors.
-Additionally, the following logits-processor-like functionalities are hard-coded into the sampler and do not yet utilize the programming model described above. Most of them will be refactored to use the aforemented logits processor programming model.
+Additionally, the following logits-processor-like functionalities are hard-coded into the sampler and do not yet utilize the programming model described above. Most of them will be refactored to use the aforementioned logits processor programming model.
* Allowed token IDs
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index ee224e6922fbd..7663b82266f0b 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -68,7 +68,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
## Fused MoE Experts Kernels
-The are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adatpers so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
+There are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adapters so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties.
Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `pplx`, `DeepEPLLPrepareAndFinalize`.
diff --git a/docs/features/custom_arguments.md b/docs/features/custom_arguments.md
index 7a650d0e79c23..728a2c89901de 100644
--- a/docs/features/custom_arguments.md
+++ b/docs/features/custom_arguments.md
@@ -5,7 +5,7 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t
Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
!!! note
- Make sure your custom logits processor have implemented `validate_params` for custom arguments. Otherwise invalid custom arguments can cause unexpected behaviour.
+ Make sure your custom logits processor has implemented `validate_params` for custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour.
## Offline Custom Arguments
diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md
index 52fcc44efacc5..5ddef9db1611b 100644
--- a/docs/features/custom_logitsprocs.md
+++ b/docs/features/custom_logitsprocs.md
@@ -71,7 +71,7 @@ Logits processor `update_state()` implementations should assume the following mo
* **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot. Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous
- * **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
+ * **Shrink the batch:** a side effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots
5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch
@@ -286,7 +286,7 @@ Once you have created a custom subclass (like `WrappedPerReqLogitsProcessor`) wh
## Ways to Load Your Custom Logits Processor in vLLM
-Logits processors are loaded at initialization. Critically, the set of loaded logits processors cannot be modified after the vLLM engine finishes loading, and new logits logits processors cannot be loaded on-demand for individual requests.
+Logits processors are loaded at initialization. Critically, the set of loaded logits processors cannot be modified after the vLLM engine finishes loading, and new logits processors cannot be loaded on-demand for individual requests.
This section details different ways of making your logits processor visible to vLLM and triggering vLLM to load your logits processor.
@@ -438,7 +438,7 @@ The examples below show how a user would pass a custom argument (`target_token`)
## Best Practices for Writing Custom Logits Processors
-Once vLLM loads a logits processor during initialization, then vLLM will invoke `update_state()` and `apply()` against that logits processor in every engine step. Both methods operate on all requests which currently reside in the vLLM persistent batch. Thus it is important to implement these methods efficiently.
+Once vLLM loads a logits processor during initialization, then vLLM will invoke `update_state()` and `apply()` against that logits processor in every engine step. Both methods operate on all requests which currently reside in the vLLM persistent batch. Thus, it is important to implement these methods efficiently.
* Write efficient `apply()` and `update_state()` implementations in light of the fact that logits processors operate at batch granularity
* For example, you may be able to use efficient vectorized operations to implement `apply()` or update internal state vectors in `update_state()`
@@ -465,4 +465,4 @@ Once vLLM loads a logits processor during initialization, then vLLM will invoke
* **Note:** for wrapped per-request logits processors, the `AdapterLogitsProcessor` base-class handles this by default
-* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method
+* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However, the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method
diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md
index 3e8cb87e37d33..fd4f249f2ec6c 100644
--- a/docs/features/disagg_prefill.md
+++ b/docs/features/disagg_prefill.md
@@ -91,6 +91,6 @@ Disaggregated prefilling is highly related to infrastructure, so vLLM relies on
We recommend three ways of implementations:
-- **Fully-customized connector**: Implement your own `Connector`, and call third-party libraries to send and receive KV caches, and many many more (like editing vLLM's model input to perform customized prefilling, etc). This approach gives you the most control, but at the risk of being incompatible with future vLLM versions.
+- **Fully-customized connector**: Implement your own `Connector`, and call third-party libraries to send and receive KV caches, and many more (like editing vLLM's model input to perform customized prefilling, etc.). This approach gives you the most control, but at the risk of being incompatible with future vLLM versions.
- **Database-like connector**: Implement your own `LookupBuffer` and support the `insert` and `drop_select` APIs just like SQL.
- **Distributed P2P connector**: Implement your own `Pipe` and support the `send_tensor` and `recv_tensor` APIs, just like `torch.distributed`.
diff --git a/docs/features/lora.md b/docs/features/lora.md
index 3a85b52d89b68..d42a3cef76bde 100644
--- a/docs/features/lora.md
+++ b/docs/features/lora.md
@@ -4,7 +4,7 @@ This document shows you how to use [LoRA adapters](https://arxiv.org/abs/2106.09
LoRA adapters can be used with any vLLM model that implements [SupportsLoRA][vllm.model_executor.models.interfaces.SupportsLoRA].
-Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save
+Adapters can be efficiently served on a per-request basis with minimal overhead. First we download the adapter(s) and save
them locally with
```python
diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md
index cde2ec165712b..5f684604e6031 100644
--- a/docs/features/multimodal_inputs.md
+++ b/docs/features/multimodal_inputs.md
@@ -483,7 +483,7 @@ Then, you can use the OpenAI client as follows:
)
# Single-image input inference
- image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
chat_response = client.chat.completions.create(
model="microsoft/Phi-3.5-vision-instruct",
diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md
index be0702f4c9e16..bd7bc186e13aa 100644
--- a/docs/features/quantization/quark.md
+++ b/docs/features/quantization/quark.md
@@ -298,7 +298,7 @@ There are two steps to generate and deploy a mixed precision model quantized wit
Firstly, the layerwise mixed-precision configuration for a given LLM model is searched and then quantized using AMD Quark. We will provide a detailed tutorial with Quark APIs later.
-As examples, we provide some ready-to-use quantized mixed precision model to show the usage in vLLM and the accuracy benifits. They are:
+As examples, we provide some ready-to-use quantized mixed precision models to show the usage in vLLM and the accuracy benefits. They are:
- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md
index e8bfca0e5e88f..d1beab7855b18 100644
--- a/docs/getting_started/installation/cpu.md
+++ b/docs/getting_started/installation/cpu.md
@@ -97,14 +97,13 @@ Currently, there are no pre-built CPU wheels.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable.
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence.
-- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).
## FAQ
### Which `dtype` should be used?
-- Currently vLLM CPU uses model default settings as `dtype`. However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if there are any performance or accuracy problem.
+- Currently, vLLM CPU uses model default settings as `dtype`. However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if there are any performance or accuracy problems.
### How to launch a vLLM service on CPU?
@@ -191,10 +190,9 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel
- GPTQ (x86 only)
- compressed-tensor INT8 W8A8 (x86, s390x)
-### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`?
+### (x86 only) What is the purpose of `VLLM_CPU_SGL_KERNEL`?
- Both of them require `amx` CPU flag.
- - `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models
- `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios.
### Why do I see `get_mempolicy: Operation not permitted` when running in Docker?
diff --git a/docs/getting_started/installation/cpu.s390x.inc.md b/docs/getting_started/installation/cpu.s390x.inc.md
index 442c2b4ec64e8..c2163139a7c5d 100644
--- a/docs/getting_started/installation/cpu.s390x.inc.md
+++ b/docs/getting_started/installation/cpu.s390x.inc.md
@@ -2,7 +2,7 @@
vLLM has experimental support for s390x architecture on IBM Z platform. For now, users must build from source to natively run on IBM Z platform.
-Currently the CPU implementation for s390x architecture supports FP32 datatype only.
+Currently, the CPU implementation for s390x architecture supports FP32 datatype only.
!!! warning
There are no pre-built wheels or images for this device, so you must build vLLM from source.
diff --git a/docs/getting_started/installation/cpu.x86.inc.md b/docs/getting_started/installation/cpu.x86.inc.md
index 00f3b726b1a0e..310f179cb89ca 100644
--- a/docs/getting_started/installation/cpu.x86.inc.md
+++ b/docs/getting_started/installation/cpu.x86.inc.md
@@ -83,7 +83,7 @@ uv pip install dist/*.whl
!!! example "Troubleshooting"
- **NumPy ≥2.0 error**: Downgrade using `pip install "numpy<2.0"`.
- **CMake picks up CUDA**: Add `CMAKE_DISABLE_FIND_PACKAGE_CUDA=ON` to prevent CUDA detection during CPU builds, even if CUDA is installed.
- - `AMD` requies at least 4th gen processors (Zen 4/Genoa) or higher to support [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU.
+ - `AMD` requires at least 4th gen processors (Zen 4/Genoa) or higher to support [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU.
- If you receive an error such as: `Could not find a version that satisfies the requirement torch==X.Y.Z+cpu+cpu`, consider updating [pyproject.toml](https://github.com/vllm-project/vllm/blob/main/pyproject.toml) to help pip resolve the dependency.
```toml title="pyproject.toml"
[build-system]
diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py
index ce1c5c53cf35a..735074c08b8c8 100644
--- a/docs/mkdocs/hooks/generate_argparse.py
+++ b/docs/mkdocs/hooks/generate_argparse.py
@@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib
+import importlib.metadata
+import importlib.util
import logging
import sys
import traceback
-from argparse import SUPPRESS, HelpFormatter
+from argparse import SUPPRESS, Action, HelpFormatter
+from collections.abc import Iterable
+from importlib.machinery import ModuleSpec
from pathlib import Path
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
from unittest.mock import MagicMock, patch
from pydantic_core import core_schema
@@ -19,6 +22,11 @@ ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse"
sys.path.insert(0, str(ROOT_DIR))
+def mock_if_no_torch(mock_module: str, mock: MagicMock):
+ if not importlib.util.find_spec("torch"):
+ sys.modules[mock_module] = mock
+
+
# Mock custom op code
class MockCustomOp:
@staticmethod
@@ -29,18 +37,21 @@ class MockCustomOp:
return decorator
-noop = lambda *a, **k: None
-sys.modules["vllm._C"] = MagicMock()
-sys.modules["vllm.model_executor.custom_op"] = MagicMock(CustomOp=MockCustomOp)
-sys.modules["vllm.utils.torch_utils"] = MagicMock(direct_register_custom_op=noop)
+mock_if_no_torch("vllm._C", MagicMock())
+mock_if_no_torch("vllm.model_executor.custom_op", MagicMock(CustomOp=MockCustomOp))
+mock_if_no_torch(
+ "vllm.utils.torch_utils", MagicMock(direct_register_custom_op=lambda *a, **k: None)
+)
+
# Mock any version checks by reading from compiled CI requirements
with open(ROOT_DIR / "requirements/test.txt") as f:
VERSIONS = dict(line.strip().split("==") for line in f if "==" in line)
importlib.metadata.version = lambda name: VERSIONS.get(name) or "0.0.0"
+
# Make torch.nn.Parameter safe to inherit from
-sys.modules["torch.nn"] = MagicMock(Parameter=object)
+mock_if_no_torch("torch.nn", MagicMock(Parameter=object))
class PydanticMagicMock(MagicMock):
@@ -49,31 +60,34 @@ class PydanticMagicMock(MagicMock):
def __init__(self, *args, **kwargs):
name = kwargs.pop("name", None)
super().__init__(*args, **kwargs)
- self.__spec__ = importlib.machinery.ModuleSpec(name, None)
+ self.__spec__ = ModuleSpec(name, None)
def __get_pydantic_core_schema__(self, source_type, handler):
return core_schema.any_schema()
-def auto_mock(module, attr, max_mocks=100):
+def auto_mock(module_name: str, attr: str, max_mocks: int = 100):
"""Function that automatically mocks missing modules during imports."""
- logger.info("Importing %s from %s", attr, module)
+ logger.info("Importing %s from %s", attr, module_name)
+
for _ in range(max_mocks):
try:
+ module = importlib.import_module(module_name)
+
# First treat attr as an attr, then as a submodule
- return getattr(
- importlib.import_module(module),
- attr,
- importlib.import_module(f"{module}.{attr}"),
- )
+ if hasattr(module, attr):
+ return getattr(module, attr)
+
+ return importlib.import_module(f"{module_name}.{attr}")
except ModuleNotFoundError as e:
+ assert e.name is not None
logger.info("Mocking %s for argparse doc generation", e.name)
sys.modules[e.name] = PydanticMagicMock(name=e.name)
- except Exception as e:
- logger.warning("Failed to import %s.%s: %s", module, attr, e)
+ except Exception:
+ logger.exception("Failed to import %s.%s", module_name, attr)
raise ImportError(
- f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
+ f"Failed to import {module_name}.{attr} after mocking {max_mocks} imports"
)
@@ -91,21 +105,26 @@ ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand")
openai_cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
openai_run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
-FlexibleArgumentParser = auto_mock(
- "vllm.utils.argparse_utils", "FlexibleArgumentParser"
-)
+
+if TYPE_CHECKING:
+ from vllm.utils.argparse_utils import FlexibleArgumentParser
+else:
+ FlexibleArgumentParser = auto_mock(
+ "vllm.utils.argparse_utils", "FlexibleArgumentParser"
+ )
class MarkdownFormatter(HelpFormatter):
"""Custom formatter that generates markdown for argument groups."""
- def __init__(self, prog, starting_heading_level=3):
- super().__init__(prog, max_help_position=float("inf"), width=float("inf"))
+ def __init__(self, prog: str, starting_heading_level: int = 3):
+ super().__init__(prog, max_help_position=sys.maxsize, width=sys.maxsize)
+
self._section_heading_prefix = "#" * starting_heading_level
self._argument_heading_prefix = "#" * (starting_heading_level + 1)
self._markdown_output = []
- def start_section(self, heading):
+ def start_section(self, heading: str):
if heading not in {"positional arguments", "options"}:
heading_md = f"\n{self._section_heading_prefix} {heading}\n\n"
self._markdown_output.append(heading_md)
@@ -113,14 +132,14 @@ class MarkdownFormatter(HelpFormatter):
def end_section(self):
pass
- def add_text(self, text):
+ def add_text(self, text: str):
if text:
self._markdown_output.append(f"{text.strip()}\n\n")
def add_usage(self, usage, actions, groups, prefix=None):
pass
- def add_arguments(self, actions):
+ def add_arguments(self, actions: Iterable[Action]):
for action in actions:
if len(action.option_strings) == 0 or "--help" in action.option_strings:
continue
@@ -169,7 +188,7 @@ def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
# Auto-mock runtime imports
if tb_list := traceback.extract_tb(e.__traceback__):
path = Path(tb_list[-1].filename).relative_to(ROOT_DIR)
- auto_mock(module=".".join(path.parent.parts), attr=path.stem)
+ auto_mock(module_name=".".join(path.parent.parts), attr=path.stem)
return create_parser(add_cli_args, **kwargs)
else:
raise e
@@ -209,7 +228,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
# Generate documentation for each parser
for stem, parser in parsers.items():
- doc_path = ARGPARSE_DOC_DIR / f"{stem}.md"
+ doc_path = ARGPARSE_DOC_DIR / f"{stem}.inc.md"
# Specify encoding for building on Windows
with open(doc_path, "w", encoding="utf-8") as f:
f.write(super(type(parser), parser).format_help())
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 2ce1cd9a943bf..3c9295b6414a7 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -351,6 +351,7 @@ th {
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|----------------------|---------------------------|
+| `AfmoeForCausalLM` | Afmoe | TBA | ✅︎ | ✅︎ |
| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ |
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ |
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ |
@@ -670,7 +671,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I+ | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + IE+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
@@ -685,7 +686,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I+ | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
-| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ |
+| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + IE+ | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ |
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + IE+ | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ |
diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index 821628e6e3174..23df3963823aa 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -293,7 +293,7 @@ and passing a list of `messages` in the request. Refer to the examples below for
base_url="http://localhost:8000/v1",
api_key="EMPTY",
)
- image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
response = create_chat_embeddings(
client,
diff --git a/docs/usage/README.md b/docs/usage/README.md
index 0c63d01f0f99f..4e8ece2c06052 100644
--- a/docs/usage/README.md
+++ b/docs/usage/README.md
@@ -1,6 +1,6 @@
# Using vLLM
-First, vLLM must be [installed](../getting_started/installation/) for your chosen device in either a Python or Docker environment.
+First, vLLM must be [installed](../getting_started/installation/README.md) for your chosen device in either a Python or Docker environment.
Then, vLLM supports the following usage patterns:
diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py
index 0c09e603271de..6f05968ce065e 100644
--- a/examples/offline_inference/rlhf.py
+++ b/examples/offline_inference/rlhf.py
@@ -62,7 +62,7 @@ ray.init()
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
-# https://docs.ray.io/en/latest/placement-groups.html
+# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
diff --git a/examples/offline_inference/rlhf_online_quant.py b/examples/offline_inference/rlhf_online_quant.py
new file mode 100644
index 0000000000000..2d98ad22c589e
--- /dev/null
+++ b/examples/offline_inference/rlhf_online_quant.py
@@ -0,0 +1,162 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Demonstrates RLHF using vLLM and Ray, with torchao online quantization applied to the rollout engine.
+
+The script separates training and inference workloads onto distinct GPUs
+so that Ray can manage process placement and inter-process communication.
+A Hugging Face Transformer model occupies GPU 0 for training, whereas a
+tensor-parallel vLLM inference engine occupies GPU 1–2.
+
+The example performs the following steps:
+
+* Load the training model on GPU 0.
+* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
+ and Ray placement groups.
+* Generate text from a list of prompts using the inference engine.
+* Update the weights of the training model and broadcast the updated weights
+ to the inference engine by using a Ray collective RPC group. Note that
+ for demonstration purposes we simply zero out the weights.
+
+For a production-ready implementation that supports multiple training and
+inference replicas, see the OpenRLHF framework:
+https://github.com/OpenRLHF/OpenRLHF
+
+This example assumes a single-node cluster with three GPUs, but Ray
+supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
+workloads. Residual GPU activity interferes with vLLM memory profiling and
+causes unexpected behavior.
+"""
+
+import json
+import os
+
+import ray
+import torch
+from ray.util.placement_group import placement_group
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from rlhf_utils import stateless_init_process_group
+from torchao.core.config import config_to_dict
+from torchao.quantization import (
+ Float8DynamicActivationFloat8WeightConfig,
+ PerRow,
+)
+from transformers import AutoModelForCausalLM
+
+from vllm import LLM, SamplingParams
+from vllm.utils.network_utils import get_ip, get_open_port
+
+
+class MyLLM(LLM):
+ """Configure the vLLM worker for Ray placement group execution."""
+
+ def __init__(self, *args, **kwargs):
+ # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
+ # so that vLLM can manage its own device placement within the worker.
+ os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+ super().__init__(*args, **kwargs)
+
+
+# Load the OPT-125M model onto GPU 0 for the training workload.
+train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+train_model.to("cuda:0")
+
+# Initialize Ray and set the visible devices. The vLLM engine will
+# be placed on GPUs 1 and 2.
+os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
+ray.init()
+
+# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
+# Learn more about Ray placement groups:
+# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
+pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
+ray.get(pg_inference.ready())
+scheduling_inference = PlacementGroupSchedulingStrategy(
+ placement_group=pg_inference,
+ placement_group_capture_child_tasks=True,
+ placement_group_bundle_index=0,
+)
+
+# Launch the vLLM inference engine. The `enforce_eager` flag reduces
+# start-up latency.
+
+# generate torchao quantization config for RL rollout
+# see https://github.com/vllm-project/vllm/pull/23014 for instructions to
+# use serialized config files instead of passing around json string
+config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+
+json_str = json.dumps(config_to_dict(config))
+
+llm = ray.remote(
+ num_cpus=0,
+ num_gpus=0,
+ scheduling_strategy=scheduling_inference,
+)(MyLLM).remote(
+ model="facebook/opt-125m",
+ hf_overrides={"quantization_config_dict_json": json_str},
+ enforce_eager=True,
+ worker_extension_cls="rlhf_utils.WorkerExtension",
+ tensor_parallel_size=2,
+ distributed_executor_backend="ray",
+)
+
+# Generate text from the prompts.
+prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+]
+
+sampling_params = SamplingParams(temperature=0)
+
+outputs = ray.get(llm.generate.remote(prompts, sampling_params))
+
+print("-" * 50)
+for output in outputs:
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+ print("-" * 50)
+
+# Set up the communication channel between the training process and the
+# inference engine.
+master_address = get_ip()
+master_port = get_open_port()
+
+handle = llm.collective_rpc.remote(
+ "init_weight_update_group", args=(master_address, master_port, 1, 3)
+)
+
+model_update_group = stateless_init_process_group(
+ master_address, master_port, 0, 3, torch.device("cuda:0")
+)
+ray.get(handle)
+
+# Simulate a training step by zeroing out all model weights.
+# In a real RLHF training loop the weights would be updated using the gradient
+# from an RL objective such as PPO on a reward model.
+for name, p in train_model.named_parameters():
+ p.data.zero_()
+
+# Synchronize the updated weights to the inference engine.
+for name, p in train_model.named_parameters():
+ dtype_name = str(p.dtype).split(".")[-1]
+ handle = llm.collective_rpc.remote(
+ "update_weight", args=(name, dtype_name, p.shape)
+ )
+ model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
+ ray.get(handle)
+
+# Verify that the inference weights have been updated.
+assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
+
+# Generate text with the updated model. The output is expected to be nonsense
+# because the weights are zero.
+outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
+print("-" * 50)
+for output in outputs_updated:
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+ print("-" * 50)
diff --git a/examples/offline_inference/vision_language_pooling.py b/examples/offline_inference/vision_language_pooling.py
index 63d85d5d9eef5..530aad4bc031c 100644
--- a/examples/offline_inference/vision_language_pooling.py
+++ b/examples/offline_inference/vision_language_pooling.py
@@ -266,7 +266,7 @@ def get_query(modality: QueryModality):
return ImageQuery(
modality="image",
image=fetch_image(
- "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg" # noqa: E501
+ "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
),
)
@@ -275,7 +275,7 @@ def get_query(modality: QueryModality):
modality="text+image",
text="A cat standing in the snow.",
image=fetch_image(
- "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg" # noqa: E501
+ "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg" # noqa: E501
),
)
diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py
index 520cbca003aa5..3d1259276998d 100644
--- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py
+++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py
@@ -66,7 +66,7 @@ def run_text_only(model: str, max_completion_tokens: int) -> None:
# Single-image input inference
def run_single_image(model: str, max_completion_tokens: int) -> None:
## Use image url in the payload
- image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
chat_completion_from_url = client.chat.completions.create(
messages=[
{
diff --git a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py
index 261b810ce5d03..47c2c5030078c 100644
--- a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py
+++ b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py
@@ -21,7 +21,7 @@ from PIL import Image
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
-image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
def create_chat_embeddings(
diff --git a/format.sh b/format.sh
deleted file mode 100755
index 6ba93e0a19ba8..0000000000000
--- a/format.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/bash
-
-echo "vLLM linting system has been moved from format.sh to pre-commit hooks."
-echo "Please run 'pip install -r requirements/lint.txt', followed by"
-echo "'pre-commit install' to install the pre-commit hooks."
-echo "Then linters will run automatically before each commit."
\ No newline at end of file
diff --git a/requirements/common.txt b/requirements/common.txt
index ad92ba3ad8278..1058ab91a02a5 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -30,7 +30,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
-gguf >= 0.13.0
+gguf >= 0.17.0
mistral_common[image] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt
index 331d02be6621e..81d429a5e5f8d 100644
--- a/requirements/cpu-build.txt
+++ b/requirements/cpu-build.txt
@@ -4,8 +4,9 @@ packaging>=24.2
setuptools>=77.0.3,<81.0.0
setuptools-scm>=8
--extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.8.0+cpu; platform_machine == "x86_64"
-torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
+torch==2.8.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
+torch==2.9.0; platform_system == "Darwin"
+torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL)
wheel
jinja2>=3.1.6
diff --git a/requirements/cpu.txt b/requirements/cpu.txt
index d11787df4d92b..e23d3286f3f78 100644
--- a/requirements/cpu.txt
+++ b/requirements/cpu.txt
@@ -22,7 +22,6 @@ datasets # for benchmark scripts
# Intel Extension for PyTorch, only for x86_64 CPUs
intel-openmp==2024.2.1; platform_machine == "x86_64"
-intel_extension_for_pytorch==2.8.0; platform_machine == "x86_64"
triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
# Use this to gather CPU info and optimize based on ARM Neoverse cores
diff --git a/setup.py b/setup.py
index e9b36e2a2e037..5591bcb132447 100644
--- a/setup.py
+++ b/setup.py
@@ -299,6 +299,20 @@ class cmake_build_ext(build_ext):
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
self.copy_file(file, dst_file)
+ if _is_cuda() or _is_hip():
+ # copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
+ # to current directory so that they can be included in the editable
+ # build
+ print(
+ f"Copying {self.build_lib}/vllm/third_party/triton_kernels "
+ "to vllm/third_party/triton_kernels"
+ )
+ shutil.copytree(
+ f"{self.build_lib}/vllm/third_party/triton_kernels",
+ "vllm/third_party/triton_kernels",
+ dirs_exist_ok=True,
+ )
+
class precompiled_build_ext(build_ext):
"""Disables extension building when using precompiled binaries."""
@@ -633,6 +647,9 @@ ext_modules = []
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
+ # Optional since this doesn't get built (no .so file is produced). This is just
+ # copying the relevant .py files from the source repository.
+ ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py
index e1560efb3f247..f22d60ef000b2 100644
--- a/tests/compile/test_fusions_e2e.py
+++ b/tests/compile/test_fusions_e2e.py
@@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import flat_product, multi_gpu_test
+is_blackwell = lambda: current_platform.is_device_capability(100)
+"""Whether we are running on Blackwell; many tests depend on it."""
+
+
+class Matches(NamedTuple):
+ attention_fusion: int = 0
+ allreduce_fusion: int = 0
+ sequence_parallel: int = 0
+ async_tp: int = 0
+
class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: AttentionBackendEnum
- attention_fusions: int
- allreduce_fusions: int | None = None
+ matches: Matches
MODELS_FP8: list[ModelBackendTestCase] = []
@@ -38,17 +47,33 @@ if current_platform.is_cuda():
ModelBackendTestCase(
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
- model_kwargs=dict(max_model_len=1024),
- backend=AttentionBackendEnum.TRITON_ATTN,
- attention_fusions=32,
- allreduce_fusions=65,
+ # TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
+ # so FI attention+fp8_quant is at least tested once
+ model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
+ backend=AttentionBackendEnum.FLASHINFER
+ if is_blackwell()
+ else AttentionBackendEnum.TRITON_ATTN,
+ matches=Matches(
+ attention_fusion=32,
+ allreduce_fusion=65,
+ sequence_parallel=65,
+ async_tp=128,
+ ),
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
- backend=AttentionBackendEnum.FLASHINFER,
- attention_fusions=48,
- allreduce_fusions=96,
+ # TODO FlashInfer attn broken on Hopper with kvcache=fp8:
+ # https://github.com/vllm-project/vllm/issues/28568
+ # TODO FlashInfer attn broken on Blackwell for llama4:
+ # https://github.com/vllm-project/vllm/issues/28604
+ backend=AttentionBackendEnum.TRITON_ATTN,
+ matches=Matches(
+ attention_fusion=48,
+ allreduce_fusion=96,
+ sequence_parallel=96,
+ async_tp=95, # mlp is moe, no fusion there
+ ),
),
]
@@ -57,8 +82,12 @@ if current_platform.is_cuda():
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER,
- attention_fusions=32,
- allreduce_fusions=65,
+ matches=Matches(
+ attention_fusion=32,
+ allreduce_fusion=65,
+ sequence_parallel=65,
+ async_tp=128,
+ ),
),
]
@@ -68,15 +97,23 @@ if current_platform.is_cuda():
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
- attention_fusions=0,
- allreduce_fusions=65,
+ matches=Matches(
+ attention_fusion=0,
+ allreduce_fusion=65,
+ sequence_parallel=65,
+ async_tp=128,
+ ),
),
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
- attention_fusions=0,
- allreduce_fusions=97,
+ matches=Matches(
+ attention_fusion=0,
+ allreduce_fusion=97,
+ sequence_parallel=97,
+ async_tp=96, # MLP is MoE, half the fusions of dense
+ ),
),
]
@@ -86,19 +123,19 @@ elif current_platform.is_rocm():
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
- attention_fusions=32,
+ matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_ATTN,
- attention_fusions=32,
+ matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
- attention_fusions=32,
+ matches=Matches(attention_fusion=32),
),
]
@@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
@pytest.mark.parametrize(
- "model_name, model_kwargs, backend, "
- "attention_fusions, allreduce_fusions, custom_ops",
+ "model_name, model_kwargs, backend, matches, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl
@@ -118,15 +154,14 @@ def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: AttentionBackendEnum,
- attention_fusions: int,
- allreduce_fusions: int,
+ matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if backend == AttentionBackendEnum.FLASHINFER and (
- not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
+ not is_blackwell() or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@@ -169,12 +204,12 @@ def test_attn_quant(
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)
- matches = re.findall(
+ log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
- assert len(matches) == 1, log_holder.text
- assert int(matches[0]) == attention_fusions
+ assert len(log_matches) == 1, log_holder.text
+ assert int(log_matches[0]) == matches.attention_fusion
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
@@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
- "model_name, model_kwargs, backend, "
- "attention_fusions, allreduce_fusions, custom_ops",
+ "model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
@@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: AttentionBackendEnum,
- attention_fusions: int,
- allreduce_fusions: int,
+ matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
@@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
+ if "fp4" in model_name.lower() and not is_blackwell():
+ pytest.skip("NVFP4 quant requires Blackwell")
+
+ if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
+ # FlashInfer attn fusion requires Blackwell
+ matches = matches._replace(attention_fusion=0)
+
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
@@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
- matches = re.findall(
+ log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
- assert len(matches) == 2, log_holder.text
+ assert len(log_matches) == 2, log_holder.text
- assert int(matches[0]) == attention_fusions
- assert int(matches[1]) == attention_fusions
+ assert int(log_matches[0]) == matches.attention_fusion
+ assert int(log_matches[1]) == matches.attention_fusion
- matches = re.findall(
+ log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
- assert len(matches) == 2, log_holder.text
+ assert len(log_matches) == 2, log_holder.text
- assert int(matches[0]) == allreduce_fusions
- assert int(matches[1]) == allreduce_fusions
+ assert int(log_matches[0]) == matches.allreduce_fusion
+ assert int(log_matches[1]) == matches.allreduce_fusion
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+ "model_name, model_kwargs, backend, matches, custom_ops",
+ # Toggle RMSNorm and QuantFP8 for FP8 models
+ list(
+ flat_product(
+ MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
+ )
+ )
+ # Toggle RMSNorm for FP4 models and unquant models
+ + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
+)
+@pytest.mark.parametrize("inductor_graph_partition", [True, False])
+@pytest.mark.skipif(
+ not current_platform.is_cuda(),
+ reason="sequence parallel only tested on CUDA",
+)
+def test_tp2_attn_quant_async_tp(
+ model_name: str,
+ model_kwargs: dict,
+ backend: AttentionBackendEnum,
+ matches: Matches,
+ custom_ops: str,
+ inductor_graph_partition: bool,
+ caplog_mp_spawn,
+ monkeypatch,
+):
+ if is_blackwell():
+ # TODO: https://github.com/vllm-project/vllm/issues/27893
+ pytest.skip("Blackwell is not supported for AsyncTP pass")
+
+ if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+ pytest.skip("Inductor graph partition requires torch>=2.9")
+
+ if "fp4" in model_name.lower() and not is_blackwell():
+ pytest.skip("NVFP4 quant requires Blackwell")
+
+ if backend == AttentionBackendEnum.FLASHINFER:
+ if not has_flashinfer():
+ pytest.skip("FlashInfer backend requires flashinfer installed")
+ if not is_blackwell():
+ # FlashInfer attn fusion requires Blackwell
+ matches = matches._replace(attention_fusion=0)
+
+ custom_ops_list = custom_ops.split(",") if custom_ops else []
+
+ if inductor_graph_partition:
+ mode = CUDAGraphMode.FULL_AND_PIECEWISE
+ splitting_ops: list[str] | None = None
+ else:
+ mode = CUDAGraphMode.FULL_DECODE_ONLY
+ splitting_ops = []
+
+    # Disable compile cache to make sure custom passes run.
+ # Otherwise, we can't verify fusion happened through the logs.
+ monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+ # To capture subprocess logs, we need to know whether spawn or fork is used.
+ # Force spawn as it is more general.
+ monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+ monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
+ compilation_config = CompilationConfig(
+ # Testing properties
+ use_inductor_graph_partition=inductor_graph_partition,
+ cudagraph_mode=mode,
+ custom_ops=custom_ops_list,
+ splitting_ops=splitting_ops,
+ # Common
+ level=CompilationMode.VLLM_COMPILE,
+ pass_config=PassConfig(
+ enable_attn_fusion=True,
+ enable_noop=True,
+ enable_sequence_parallelism=True,
+ enable_async_tp=True,
+ ),
+ # Inductor caches custom passes by default as well via uuid
+ inductor_compile_config={"force_disable_caches": True},
+ )
+
+ with caplog_mp_spawn(logging.DEBUG) as log_holder:
+ run_model(
+ compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
+ )
+ log_matches = re.findall(
+ r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
+ log_holder.text,
+ )
+ assert len(log_matches) == 2, log_holder.text
+
+ assert int(log_matches[0]) == matches.attention_fusion
+ assert int(log_matches[1]) == matches.attention_fusion
+
+ log_matches = re.findall(
+ r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
+ log_holder.text,
+ )
+ assert len(log_matches) == 2, log_holder.text
+
+ assert int(log_matches[0]) == matches.sequence_parallel
+ assert int(log_matches[1]) == matches.sequence_parallel
+
+ log_matches = re.findall(
+ r"collective_fusion.py:\d+] Replaced (\d+) patterns",
+ log_holder.text,
+ )
+ assert len(log_matches) == 2, log_holder.text
+
+ assert int(log_matches[0]) == matches.async_tp
+ assert int(log_matches[1]) == matches.async_tp
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py
index e909cf7393ad3..9cd7f64b04af5 100644
--- a/tests/compile/test_sequence_parallelism.py
+++ b/tests/compile/test_sequence_parallelism.py
@@ -5,15 +5,15 @@ import pytest
import torch
import vllm.envs as envs
-from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
-from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
+from vllm.compilation.fx_utils import find_auto_fn
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (
CompilationConfig,
+ CUDAGraphMode,
DeviceConfig,
ModelConfig,
PassConfig,
@@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
@@ -43,172 +44,157 @@ prompts = [
]
-class TestModel(torch.nn.Module):
- def __init__(self, hidden_size=16, intermediate_size=32):
+class TestAllReduceRMSNormModel(torch.nn.Module):
+ def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.gate_proj = torch.nn.Parameter(
- torch.empty((intermediate_size, hidden_size))
- )
- self.norm = RMSNorm(intermediate_size, 1e-05)
- # Initialize weights
- torch.nn.init.normal_(self.gate_proj, std=0.02)
+ self.eps = eps
+ self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+ self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
- def forward(self, hidden_states, residual):
- """
- Forward pass implementing the operations in the FX graph
+ def forward(self, x):
+ z = torch.relu(x)
+ x = resid = tensor_model_parallel_all_reduce(z)
+ y = self.norm[0](x)
- Args:
- hidden_states: Input tensor
- residual: Residual tensor from previous layer
+ z2 = torch.mm(y, self.w[0])
+ x2 = tensor_model_parallel_all_reduce(z2)
- Returns:
- Tuple containing the output tensor
- """
- # Reshape input
- view = hidden_states.reshape(-1, self.hidden_size)
+ y2, resid = self.norm[1](x2, resid)
- # matrix multiplication
- permute = self.gate_proj.permute(1, 0)
- mm = torch.mm(view, permute)
+ z3 = torch.mm(y2, self.w[1])
+ x3 = tensor_model_parallel_all_reduce(z3)
- # Tensor parallel all-reduce
- all_reduce = tensor_model_parallel_all_reduce(mm)
+ y3, resid = self.norm[2](x3, resid)
- # layer normalization
- norm_output, residual_output = self.norm(all_reduce, residual)
+ z4 = torch.mm(y3, self.w[2])
+ x4 = tensor_model_parallel_all_reduce(z4)
- return norm_output, residual_output
+ y4, resid = self.norm[3](x4, resid)
+ return y4
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
def ops_in_model_after(self):
return [
- torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default,
+ torch.ops.vllm.reduce_scatter.default,
]
def ops_in_model(self):
- return [torch.ops._C.fused_add_rms_norm.default]
+ if RMSNorm.enabled():
+ return [
+ torch.ops._C.rms_norm.default,
+ torch.ops._C.fused_add_rms_norm.default,
+ ]
+ else:
+ return []
-class TestQuantModel(torch.nn.Module):
- def __init__(self, hidden_size=16, intermediate_size=32):
+class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
+ def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
self.vllm_config = get_current_vllm_config()
- self.gate_proj = torch.nn.Parameter(
- torch.empty((intermediate_size, hidden_size)), requires_grad=False
- )
- self.norm = RMSNorm(intermediate_size, 1e-05)
- # Initialize weights
- torch.nn.init.normal_(self.gate_proj, std=0.02)
+ self.hidden_size = hidden_size
+ self.eps = eps
+ self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
+ self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
+ self.w = [
+ torch.rand(hidden_size, hidden_size)
+ .to(dtype=current_platform.fp8_dtype())
+ .t()
+ for _ in range(3)
+ ]
- self.fp8_linear = Fp8LinearOp(act_quant_static=True)
-
- self.scale = torch.rand(1, dtype=torch.float32)
- # Create a weight that is compatible with torch._scaled_mm,
- # which expects a column-major layout.
- self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
- self.wscale = torch.rand(1, dtype=torch.float32)
-
- def forward(self, hidden_states, residual):
- """
- Forward pass implementing the operations in the FX graph
-
- Args:
- hidden_states: Input tensor
- residual: Residual tensor from previous layer
-
- Returns:
- Tuple containing the output tensor
- """
- # Reshape input
- view = hidden_states.reshape(-1, self.hidden_size)
-
- # matrix multiplication
- permute = self.gate_proj.permute(1, 0)
- mm = torch.mm(view, permute)
-
- # Tensor parallel all-reduce
- all_reduce = tensor_model_parallel_all_reduce(mm)
-
- # layer normalization
- norm_output, residual_output = self.norm(all_reduce, residual)
-
- # scaled_mm with static input quantization
- fp8_linear_result = self.fp8_linear.apply(
- norm_output,
- self.w,
- self.wscale,
- input_scale=self.scale.to(norm_output.device),
+ self.fp8_linear = Fp8LinearOp(
+ act_quant_static=True,
+ act_quant_group_shape=GroupShape.PER_TENSOR,
)
- return fp8_linear_result, residual_output
+ self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
- def ops_in_model_before(self):
- ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
- # The following are only removed if fusion happens
- if (
- self.vllm_config
- and self.vllm_config.compilation_config.pass_config.enable_fusion
- ):
- ops_to_remove.extend(
- [
- torch.ops._C.fused_add_rms_norm.default,
- torch.ops._C.static_scaled_fp8_quant.default,
- ]
- )
- return ops_to_remove
+ def forward(self, hidden_states):
+ # avoid having graph input be an arg to a pattern directly
+ z = torch.relu(hidden_states)
+ x = resid = tensor_model_parallel_all_reduce(z)
+ y = self.norm[0](x)
+
+ z2 = self.fp8_linear.apply(
+ y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+ )
+
+ x2 = tensor_model_parallel_all_reduce(z2)
+ y2, resid = self.norm[1](x2, resid)
+
+ z3 = self.fp8_linear.apply(
+ y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+ )
+
+ x3 = tensor_model_parallel_all_reduce(z3)
+ y3, resid = self.norm[2](x3, resid) # use resid here
+
+ z4 = self.fp8_linear.apply(
+ y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
+ )
+ x4 = tensor_model_parallel_all_reduce(z4)
+ y4, resid = self.norm[3](x4, resid) # use resid here
+ return y4
def ops_in_model_after(self):
- ops_to_add = [
- torch.ops.vllm.reduce_scatter.default,
+ return [
torch.ops.vllm.all_gather.default,
+ torch.ops.vllm.reduce_scatter.default,
+ ]
+
+ def ops_in_model_before(self):
+ return [
+ torch.ops.vllm.all_reduce.default,
]
- # The following is only added if fusion happens
- if (
- self.vllm_config
- and self.vllm_config.compilation_config.pass_config.enable_fusion
- ):
- ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
- return ops_to_add
def ops_in_model(self):
- if (
- self.vllm_config
- and self.vllm_config.compilation_config.pass_config.enable_fusion
- ):
- # If fusion happens, the fused op is the one
- # we check for (de)functionalization
+ if self.vllm_config.compilation_config.pass_config.enable_fusion:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
- else:
- # If no fusion, the original ops are checked
+ elif RMSNorm.enabled():
return [
torch.ops._C.fused_add_rms_norm.default,
- # TODO functionalization pass does not handle this yet
- # torch.ops._C.static_scaled_fp8_quant.default,
]
+ elif self.fp8_linear.quant_fp8.enabled():
+ return [
+ torch.ops._C.static_scaled_fp8_quant.default,
+ ]
+ else:
+ return []
@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
+@pytest.mark.parametrize(
+ "test_model_cls, custom_ops",
+ [
+ (TestAllReduceRMSNormModel, "+rms_norm"),
+ (TestAllReduceRMSNormModel, "-rms_norm"),
+ (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
+ (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
+ (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
+ (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
+ ],
+)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
+@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
test_model_cls: type[torch.nn.Module],
+ custom_ops: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
+ dynamic: bool,
):
num_processes = 2
@@ -220,11 +206,13 @@ def test_sequence_parallelism_pass(
args=(
num_processes,
test_model_cls,
+ custom_ops,
batch_size,
seq_len,
hidden_size,
dtype,
enable_fusion,
+ dynamic,
),
nprocs=nprocs,
)
@@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: type[torch.nn.Module],
+ custom_ops: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
+ dynamic: bool,
):
current_platform.seed_everything(0)
@@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
+ custom_ops_list = custom_ops.split(",") if custom_ops else []
compilation_config = CompilationConfig(
+ splitting_ops=[], # avoid automatic rms_norm enablement
+ cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
+ custom_ops=custom_ops_list,
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
- )
+ ),
) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda"))
@@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model(
with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
- func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
@@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model(
passes_for_backend.append(cleanup_pass)
- backend_no_func = TestBackend(*passes_for_backend)
- backend_func = TestBackend(*passes_for_backend, func_pass)
+ backend = TestBackend(*passes_for_backend)
- model = test_model_cls(hidden_size, hidden_size * 2)
+ model = test_model_cls(hidden_size)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
- residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
- compiled_model_no_func = torch.compile(model, backend=backend_no_func)
- compiled_model_no_func(hidden_states, residual)
- compiled_model_func = torch.compile(model, backend=backend_func)
- compiled_model_func(hidden_states, residual)
+ if dynamic:
+ torch._dynamo.mark_dynamic(hidden_states, 0)
- assert sequence_parallelism_pass.matched_count == 1
+ compiled_model = torch.compile(model, backend=backend)
+ compiled_model(hidden_states)
+
+ assert sequence_parallelism_pass.matched_count == 4
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
- backend_no_func.check_before_ops(model.ops_in_model_before())
+ for op in model.ops_in_model_before():
+ assert backend.op_count(op, before=True) == 4
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
- backend_no_func.check_after_ops(model.ops_in_model_after())
+ for op in model.ops_in_model_after():
+ assert backend.op_count(op, before=False) == 4
- # check if the functionalization pass is applied
for op in model.ops_in_model():
- find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
- assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
-
- # make sure the ops were all de-functionalized
- found = dict()
- for node in backend_func.graph_post_pass.nodes:
- for op in model.ops_in_model():
- if is_func(node, op):
- found[op] = True
- assert all(found[op] for op in model.ops_in_model())
+ find_auto_fn(backend.graph_post_pass.nodes, op)
diff --git a/tests/distributed/test_multiproc_executor.py b/tests/distributed/test_multiproc_executor.py
new file mode 100644
index 0000000000000..e741a79bc4ed9
--- /dev/null
+++ b/tests/distributed/test_multiproc_executor.py
@@ -0,0 +1,437 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Integration tests for MultiprocExecutor at the executor level.
+This test directly tests the executor without going through the LLM interface,
+focusing on executor initialization, RPC calls, and distributed execution.
+"""
+
+import multiprocessing
+import os
+
+from tests.utils import multi_gpu_test
+from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import get_open_port
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+
+MODEL = "facebook/opt-125m"
+
+
+def create_vllm_config(
+ tensor_parallel_size: int = 1,
+ pipeline_parallel_size: int = 1,
+ max_model_len: int = 256,
+ gpu_memory_utilization: float = 0.3,
+ distributed_executor_backend: str = "mp",
+ nnodes: int = 1,
+ node_rank: int = 0,
+ master_port: int = 0,
+) -> VllmConfig:
+ """Create a VllmConfig for testing using EngineArgs."""
+ engine_args = EngineArgs(
+ model=MODEL,
+ tensor_parallel_size=tensor_parallel_size,
+ pipeline_parallel_size=pipeline_parallel_size,
+ max_model_len=max_model_len,
+ gpu_memory_utilization=gpu_memory_utilization,
+ distributed_executor_backend=distributed_executor_backend,
+ enforce_eager=True,
+ )
+ vllm_config = engine_args.create_engine_config()
+
+ # Override distributed node settings if needed
+ if nnodes > 1 or node_rank > 0:
+ vllm_config.parallel_config.nnodes = nnodes
+ vllm_config.parallel_config.node_rank = node_rank
+ vllm_config.parallel_config.master_port = master_port
+ if nnodes > 1:
+ vllm_config.parallel_config.disable_custom_all_reduce = True
+
+ return vllm_config
+
+
+def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput:
+ """Create a minimal SchedulerOutput for testing."""
+ # This is a simplified version - in practice you'd need proper
+ # SchedulerOutput construction based on the actual vLLM v1 API
+ return SchedulerOutput(
+ scheduled_new_reqs=[],
+ scheduled_resumed_reqs=[],
+ scheduled_running_reqs=[],
+ num_scheduled_tokens={},
+ total_num_scheduled_tokens=0,
+ )
+
+
+def test_multiproc_executor_initialization():
+ """Test that MultiprocExecutor can be initialized with proper config."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=1,
+ pipeline_parallel_size=1,
+ )
+
+ # Create executor - this should initialize workers
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ # Verify executor properties
+ assert executor.world_size == 1, "World size should be 1 for single GPU"
+ assert executor.local_world_size == 1, "Local world size should be 1"
+ assert hasattr(executor, "workers"), "Executor should have workers"
+ assert len(executor.workers) == 1, "Should have 1 worker for single GPU"
+
+ # Clean up
+ executor.shutdown()
+
+
+@multi_gpu_test(num_gpus=2)
+def test_multiproc_executor_initialization_tensor_parallel():
+ """Test MultiprocExecutor initialization with tensor parallelism."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=2,
+ pipeline_parallel_size=1,
+ )
+
+ # Create executor
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ # Verify executor properties
+ assert executor.world_size == 2, "World size should be 2 for TP=2"
+ assert executor.local_world_size == 2, "Local world size should be 2"
+ assert len(executor.workers) == 2, "Should have 2 workers for TP=2"
+
+ # Verify output rank calculation
+ output_rank = executor._get_output_rank()
+ assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1"
+
+ # Clean up
+ executor.shutdown()
+
+
+@multi_gpu_test(num_gpus=2)
+def test_multiproc_executor_collective_rpc():
+ """Test collective RPC calls to all workers."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=2,
+ pipeline_parallel_size=1,
+ )
+
+ # Create executor
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ try:
+ # Test check_health RPC - should work without errors
+ executor.check_health()
+
+ # Test that RPC works correctly
+ # Note: We're just testing that the RPC mechanism works,
+ # not testing actual model execution here
+ assert not executor.is_failed, "Executor should not be in failed state"
+
+ finally:
+ # Clean up
+ executor.shutdown()
+
+
+def test_multiproc_executor_failure_callback():
+ """Test failure callback registration and invocation."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=1,
+ pipeline_parallel_size=1,
+ )
+
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ try:
+ # Test callback registration
+ callback_invoked = []
+
+ def test_callback():
+ callback_invoked.append(True)
+
+ # Register callback
+ executor.register_failure_callback(test_callback)
+
+ # Callback should not be invoked yet
+ assert len(callback_invoked) == 0, "Callback should not be invoked immediately"
+
+ # Simulate failure
+ executor.is_failed = True
+
+ # Register another callback - should be invoked immediately
+ executor.register_failure_callback(test_callback)
+ assert len(callback_invoked) == 1, (
+ "Callback should be invoked when executor is failed"
+ )
+
+ finally:
+ # Clean up
+ executor.shutdown()
+
+
+@multi_gpu_test(num_gpus=2)
+def test_multiproc_executor_worker_monitor():
+ """Test that worker monitor is set up correctly."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=2,
+ pipeline_parallel_size=1,
+ )
+
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ try:
+ # Verify all worker processes are alive
+ for worker in executor.workers:
+ assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive"
+
+ # Verify executor is not in failed state
+ assert not executor.is_failed, "Executor should not be in failed state"
+
+ finally:
+ # Clean up
+ executor.shutdown()
+
+ # After shutdown, workers should be terminated
+ import time
+
+ time.sleep(0.5) # Give processes time to terminate
+ for worker in executor.workers:
+ assert not worker.proc.is_alive(), (
+ f"Worker rank {worker.rank} should terminate after shutdown"
+ )
+
+
+@multi_gpu_test(num_gpus=2)
+def test_multiproc_executor_get_response_message_queues():
+ """Test message queue retrieval for different ranks."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=2,
+ pipeline_parallel_size=1,
+ )
+
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ try:
+ # Get all message queues
+ all_queues = executor.get_response_mqs()
+ assert len(all_queues) == 2, "Should have 2 message queues for 2 workers"
+
+ # Get message queue for specific rank
+ rank0_queue = executor.get_response_mqs(unique_reply_rank=0)
+ assert len(rank0_queue) == 1, "Should have 1 message queue for rank 0"
+
+ rank1_queue = executor.get_response_mqs(unique_reply_rank=1)
+ assert len(rank1_queue) == 1, "Should have 1 message queue for rank 1"
+
+ finally:
+ # Clean up
+ executor.shutdown()
+
+
+def test_multiproc_executor_shutdown_cleanup():
+ """Test that shutdown properly cleans up resources."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=1,
+ pipeline_parallel_size=1,
+ )
+
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ # Verify executor is set up
+ assert hasattr(executor, "workers"), "Executor should have workers"
+ assert len(executor.workers) > 0, "Should have at least one worker"
+
+ # Shutdown
+ executor.shutdown()
+
+ # Verify cleanup
+ import time
+
+ time.sleep(0.5) # Give processes time to terminate
+
+ for worker in executor.workers:
+ assert not worker.proc.is_alive(), "Worker processes should be terminated"
+
+ # Verify shutdown event is set
+ assert executor.shutdown_event.is_set(), "Shutdown event should be set"
+
+ # Multiple shutdowns should be safe (idempotent)
+ executor.shutdown()
+ executor.shutdown()
+
+
+@multi_gpu_test(num_gpus=4)
+def test_multiproc_executor_pipeline_parallel():
+ """Test MultiprocExecutor with pipeline parallelism."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=2,
+ pipeline_parallel_size=2,
+ )
+
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ try:
+ # Verify executor properties
+ assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2"
+ assert len(executor.workers) == 4, "Should have 4 workers"
+
+ # Verify output rank calculation
+ # For TP=2, PP=2: output should be from the last PP stage (ranks 2-3)
+ # Specifically rank 2 (first rank of last PP stage)
+ output_rank = executor._get_output_rank()
+ assert output_rank == 2, "Output rank should be 2 (first rank of last PP stage)"
+
+ # Verify max_concurrent_batches for pipeline parallel
+ assert executor.max_concurrent_batches == 2, (
+ "Max concurrent batches should equal PP size"
+ )
+
+ finally:
+ # Clean up
+ executor.shutdown()
+
+
+def test_multiproc_executor_properties():
+ """Test various executor properties and configurations."""
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=1,
+ pipeline_parallel_size=1,
+ )
+
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ try:
+ # Test supports_pp property
+ assert MultiprocExecutor.supports_pp is True, (
+ "MultiprocExecutor should support pipeline parallelism"
+ )
+
+ # Test world_size calculation
+ assert executor.world_size == (
+ executor.parallel_config.tensor_parallel_size
+ * executor.parallel_config.pipeline_parallel_size
+ ), "World size should equal TP * PP"
+
+ # Test local_world_size calculation
+ assert executor.local_world_size == (
+ executor.parallel_config.world_size // executor.parallel_config.nnodes
+ ), "Local world size should be world_size / nnodes"
+
+ finally:
+ # Clean up
+ executor.shutdown()
+
+
+@multi_gpu_test(num_gpus=4)
+def test_multiproc_executor_multi_node():
+ """
+ Test MultiprocExecutor with multi-node configuration.
+ This simulates 2 nodes with TP=4:
+ - Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2
+ - Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2
+ Total world_size = 4, nnodes = 2
+ """
+ port = get_open_port()
+    # symm_mem does not work when simulating multiple instances on a single node
+ os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
+
+ def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int):
+ """Run a single node's executor."""
+ executor = None
+ try:
+ # Set CUDA_VISIBLE_DEVICES for this node
+ if node_rank == 0:
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+ else:
+ os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
+
+ # Create config for this node
+ vllm_config = create_vllm_config(
+ tensor_parallel_size=4, # Total TP across all nodes
+ pipeline_parallel_size=1,
+ nnodes=2, # 2 nodes
+ node_rank=node_rank,
+ master_port=port, # same port
+ )
+
+ # Create executor for this node
+ executor = MultiprocExecutor(vllm_config=vllm_config)
+
+ # Verify node-specific properties
+ assert executor.world_size == 4, (
+ f"World size should be 4 on node {node_rank}"
+ )
+ assert executor.local_world_size == 2, (
+ f"Local world size should be 2 on node {node_rank}"
+ )
+ assert len(executor.workers) == 2, (
+ f"Should have 2 local workers on node {node_rank}"
+ )
+
+ # Verify worker ranks are correct for this node
+ expected_ranks = [node_rank * 2, node_rank * 2 + 1]
+ actual_ranks = sorted([w.rank for w in executor.workers])
+ assert actual_ranks == expected_ranks, (
+ f"Node {node_rank} should have workers "
+ f"with ranks {expected_ranks}, got {actual_ranks}"
+ )
+ # Verify all workers are alive
+ for worker in executor.workers:
+ assert worker.proc.is_alive(), (
+ f"Worker rank {worker.rank} should be alive on node {node_rank}"
+ )
+
+ # Put success result in queue BEFORE shutdown to avoid hanging
+ result_queue.put({"node": node_rank, "success": True})
+ import time
+
+ time.sleep(2)
+ executor.shutdown()
+ except Exception as e:
+ # Put failure result in queue
+ result_queue.put({"node": node_rank, "success": False, "error": str(e)})
+ raise e
+ finally:
+ if executor is not None:
+ executor.shutdown()
+
+ # Create a queue to collect results from both processes
+ result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue()
+
+ # Start both node processes
+ processes = []
+ for node_rank in range(2):
+ p = multiprocessing.Process(
+ target=run_node,
+ args=(node_rank, result_queue, port),
+ name=f"Node{node_rank}",
+ )
+ p.start()
+ processes.append(p)
+
+ # Wait for both processes to complete
+ all_completed = True
+ for p in processes:
+ p.join(timeout=60)
+ if p.is_alive():
+ p.terminate()
+ p.join(timeout=20)
+ if p.is_alive():
+ p.kill()
+ p.join()
+ all_completed = False
+
+ # Check results from both nodes
+ results: list[dict[str, int | bool]] = []
+ while len(results) < 2:
+ try:
+ result = result_queue.get(timeout=1)
+ results.append(result)
+ except Exception:
+ pass
+ assert all_completed, "Not all processes completed successfully"
+ assert len(results) == 2, f"Expected 2 results, got {len(results)}"
+ assert results[0]["success"], f"Node 0 failed: {results[0]}"
+ assert results[1]["success"], f"Node 1 failed: {results[1]}"
diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py
index 94b2b51211a64..f38c509775ed5 100644
--- a/tests/distributed/test_sequence_parallel.py
+++ b/tests/distributed/test_sequence_parallel.py
@@ -18,6 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
+from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..models.registry import HF_EXAMPLE_MODELS
@@ -161,6 +162,7 @@ def _compare_sp(
test_options: SPTestOptions,
num_gpus_available: int,
use_inductor_graph_partition: bool,
+ enable_async_tp: bool,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
@@ -244,10 +246,10 @@ def _compare_sp(
compilation_config = {
"mode": CompilationMode.VLLM_COMPILE,
- "custom_ops": ["+rms_norm"],
"compile_sizes": [4, 8],
"pass_config": {
"enable_sequence_parallelism": True,
+ "enable_async_tp": enable_async_tp,
"enable_fusion": enable_fusion,
"enable_noop": True,
},
@@ -307,6 +309,7 @@ SP_TEST_MODELS = [
],
)
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
+@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
@create_new_process_for_each_test()
def test_tp_sp_generation(
model_id: str,
@@ -316,10 +319,19 @@ def test_tp_sp_generation(
test_options: SPTestOptions,
num_gpus_available,
use_inductor_graph_partition: bool,
+ enable_async_tp: bool,
):
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+ # Skip FP8 SP-only test on sm89 (compute capability 8.9)
+ if (
+ "fp8" in model_id.lower()
+ and current_platform.get_device_capability() < (9, 0)
+ and (not enable_async_tp)
+ ):
+ pytest.skip("FP8 reduction support begins with sm90 capable devices.")
+
_compare_sp(
model_id,
parallel_setup,
@@ -328,6 +340,7 @@ def test_tp_sp_generation(
test_options,
num_gpus_available,
use_inductor_graph_partition,
+ enable_async_tp=enable_async_tp,
method="generate",
is_multimodal=False,
)
diff --git a/tests/entrypoints/openai/test_enable_force_include_usage.py b/tests/entrypoints/openai/test_enable_force_include_usage.py
index 3ddf2308eb1d5..9d527c45c1fae 100644
--- a/tests/entrypoints/openai/test_enable_force_include_usage.py
+++ b/tests/entrypoints/openai/test_enable_force_include_usage.py
@@ -17,7 +17,7 @@ def chat_server_with_force_include_usage(request): # noqa: F811
"128",
"--enforce-eager",
"--max-num-seqs",
- "1",
+ "4",
"--enable-force-include-usage",
"--port",
"55857",
@@ -78,7 +78,7 @@ def transcription_server_with_force_include_usage():
"--dtype",
"bfloat16",
"--max-num-seqs",
- "1",
+ "4",
"--enforce-eager",
"--enable-force-include-usage",
"--gpu-memory-utilization",
diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index dbcec9d31fc9b..4e7b765d7713f 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -16,6 +16,7 @@ from transformers import AutoTokenizer
from vllm import version
+from ...conftest import LocalAssetServer
from ...utils import RemoteOpenAIServer
MODELS = {
@@ -69,7 +70,6 @@ async def client(server):
_PROMPT = "Hello my name is Robert and I love magic"
-_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int):
@@ -250,6 +250,7 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
@pytest.mark.asyncio
async def test_metrics_exist(
+ local_asset_server: LocalAssetServer,
server: RemoteOpenAIServer,
client: openai.AsyncClient,
model_key: str,
@@ -265,13 +266,21 @@ async def test_metrics_exist(
temperature=0.0,
)
else:
+ # https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
await client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
- {"type": "image_url", "image_url": {"url": _IMAGE_URL}},
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": local_asset_server.url_for(
+ "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ ),
+ },
+ },
{"type": "text", "text": "What's in this image?"},
],
}
diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py
index 2a7df08ea3b0e..d83c6726e72da 100644
--- a/tests/entrypoints/openai/test_vision.py
+++ b/tests/entrypoints/openai/test_vision.py
@@ -17,10 +17,10 @@ MAXIMUM_IMAGES = 2
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_ASSETS = [
- "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
- "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
- "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
- "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+ "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ "Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
+ "1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
+ "RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
]
EXPECTED_MM_BEAM_SEARCH_RES = [
diff --git a/tests/entrypoints/pooling/openai/test_vision_embedding.py b/tests/entrypoints/pooling/openai/test_vision_embedding.py
index 944392d66fa5f..1befb5a3cf7a8 100644
--- a/tests/entrypoints/pooling/openai/test_vision_embedding.py
+++ b/tests/entrypoints/pooling/openai/test_vision_embedding.py
@@ -19,10 +19,10 @@ assert vlm2vec_jinja_path.exists()
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_ASSETS = [
- "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
- "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
- "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
- "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+ "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ "Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
+ "1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
+ "RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
]
diff --git a/tests/kernels/attention/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py
index 9b084f2f660b2..c7662223e1ca5 100644
--- a/tests/kernels/attention/test_merge_attn_states.py
+++ b/tests/kernels/attention/test_merge_attn_states.py
@@ -150,8 +150,8 @@ def test_merge_attn_states(
output_torch = output.clone()
output_lse_torch = output_lse.clone()
total_time_torch_kernel = 0
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
# 0. Run the Torch kernel
prefix_lse_torch = prefix_lse.clone()
@@ -188,8 +188,8 @@ def test_merge_attn_states(
output_lse_ref_triton = output_lse.clone()
total_time_triton_kernel = 0
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
for _ in range(warmup_times):
merge_attn_states_triton(
diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
index 62704bbcbbc79..2285709fa7d60 100644
--- a/tests/kernels/moe/test_batched_moe.py
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -40,8 +40,6 @@ NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
@dataclass
diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index cd34617ee0fc4..88db4b3e537c2 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -33,8 +33,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py
index 3799e60f1294a..e35ca4caa9dbc 100644
--- a/tests/kernels/moe/test_block_int8.py
+++ b/tests/kernels/moe/test_block_int8.py
@@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.bfloat16]
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index 5512ccce47b05..c15837f145705 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -42,8 +42,6 @@ MNK_FACTORS = [
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
@dataclasses.dataclass
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 0faf8bc95d2ec..455ecacef5ec3 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -7,6 +7,7 @@ fp8 block-quantized case.
"""
import dataclasses
+from contextlib import contextmanager
import pytest
import torch.distributed
@@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
@@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
P = ParamSpec("P")
+@contextmanager
+def with_dp_metadata(M: int, world_size: int):
+ num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
+
+ vllm_config = VllmConfig()
+ vllm_config.parallel_config.data_parallel_size = world_size
+ vllm_config.parallel_config.enable_expert_parallel = True
+
+ with set_forward_context(
+ None,
+ vllm_config,
+ num_tokens=M,
+ num_tokens_across_dp=num_tokens_across_dp,
+ ):
+ yield
+
+
def next_power_of_2(x):
import math
@@ -285,18 +304,21 @@ def deepep_deepgemm_moe_impl(
quant_config=quant_config,
)
- out = mk.forward(
- hidden_states=test_tensors.rank_tokens,
- w1=w1,
- w2=w2,
- topk_weights=test_tensors.topk_weights,
- topk_ids=test_tensors.topk,
- inplace=False,
- activation="silu",
- global_num_experts=num_experts,
- expert_map=build_expert_map(),
- apply_router_weight_on_input=False,
- )
+ with with_dp_metadata(
+ M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
+ ):
+ out = mk.forward(
+ hidden_states=test_tensors.rank_tokens,
+ w1=w1,
+ w2=w2,
+ topk_weights=test_tensors.topk_weights,
+ topk_ids=test_tensors.topk,
+ inplace=False,
+ activation="silu",
+ global_num_experts=num_experts,
+ expert_map=build_expert_map(),
+ apply_router_weight_on_input=False,
+ )
return out
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 707068b2bbdc2..218df4a2632c3 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -45,8 +45,6 @@ MNK_FACTORS = [
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
def quant_fp8_per_tensor_batches(a):
@@ -79,10 +77,14 @@ class TestData:
@staticmethod
def make_moe_tensors_8bit(
- m: int, k: int, n: int, e: int, reorder: bool
+ m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
) -> "TestData":
+ is_gated = activation != "relu2_no_mul"
+
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
- w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
+ w13 = torch.randn(
+ (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
+ )
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
# Scale to fp8
@@ -192,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
def test_flashinfer_cutlass_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
+ activation: str,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
- td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
+ td = TestData.make_moe_tensors_8bit(
+ m, k, n, e, reorder=False, activation=activation
+ )
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
@@ -235,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
- activation="silu",
+ activation=activation,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
@@ -255,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer,
topk_weights,
topk_ids,
- activation="silu",
+ activation=activation,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index c27cf2468ede5..0550c2d9e2125 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -81,8 +81,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
]
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
def run_moe_test(
diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py
index a2de64974b353..dd4eb4da913bd 100644
--- a/tests/kernels/moe/test_pplx_cutlass_moe.py
+++ b/tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -192,8 +192,6 @@ def pplx_cutlass_moe(
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
def _pplx_moe(
diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py
index 0f0ed3326d159..f671b23d300ce 100644
--- a/tests/kernels/moe/test_pplx_moe.py
+++ b/tests/kernels/moe/test_pplx_moe.py
@@ -81,8 +81,6 @@ TOP_KS = [1, 2, 6]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
def torch_prepare(
diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
index 933cd9dbdeaa0..7a467e160b784 100644
--- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
@@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py
index 55f092e7ea694..e9973c1fcc15e 100644
--- a/tests/kernels/quantization/test_block_fp8.py
+++ b/tests/kernels/quantization/test_block_fp8.py
@@ -29,8 +29,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py
index dabc10a122f7a..310091b6a554d 100644
--- a/tests/kernels/quantization/test_block_int8.py
+++ b/tests/kernels/quantization/test_block_int8.py
@@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222]
diff --git a/tests/model_executor/test_eagle_quantization.py b/tests/model_executor/test_eagle_quantization.py
new file mode 100644
index 0000000000000..1ab75933ee31e
--- /dev/null
+++ b/tests/model_executor/test_eagle_quantization.py
@@ -0,0 +1,169 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+
+from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig
+from vllm.model_executor.models.utils import get_draft_quant_config
+from vllm.platforms import current_platform
+
+DEVICES = (
+ [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+ if current_platform.is_cuda_alike()
+ else ["cpu"]
+)
+
+
+def test_get_draft_quant_config_with_draft_model():
+ mock_draft_model_config = Mock(spec=ModelConfig)
+ mock_load_config = Mock(spec=LoadConfig)
+ mock_speculative_config = Mock(spec=SpeculativeConfig)
+ mock_speculative_config.draft_model_config = mock_draft_model_config
+
+ mock_vllm_config = Mock(spec=VllmConfig)
+ mock_vllm_config.speculative_config = mock_speculative_config
+ mock_vllm_config.load_config = mock_load_config
+
+ mock_quant_config = Mock()
+ with patch.object(
+ VllmConfig, "get_quantization_config", return_value=mock_quant_config
+ ):
+ result = get_draft_quant_config(mock_vllm_config)
+
+ # Verify the function calls get_quantization_config with draft model config
+ VllmConfig.get_quantization_config.assert_called_once_with(
+ mock_draft_model_config, mock_load_config
+ )
+ assert result == mock_quant_config
+
+
+def test_get_draft_quant_config_without_draft_model():
+ mock_speculative_config = Mock(spec=SpeculativeConfig)
+ mock_speculative_config.draft_model_config = None
+
+ mock_vllm_config = Mock(spec=VllmConfig)
+ mock_vllm_config.speculative_config = mock_speculative_config
+ mock_vllm_config.load_config = Mock(spec=LoadConfig)
+
+ result = get_draft_quant_config(mock_vllm_config)
+
+ assert result is None
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("device", DEVICES)
+def test_fc_layer_quant_config_usage(dist_init, device) -> None:
+ import torch
+
+ from vllm.model_executor.layers.linear import ReplicatedLinear
+
+ if current_platform.is_cuda_alike():
+ torch.cuda.set_device(device)
+
+ torch.set_default_device(device)
+
+ input_size = 256
+ output_size = 128
+
+ fc_no_quant = ReplicatedLinear(
+ input_size=input_size,
+ output_size=output_size,
+ bias=False,
+ params_dtype=torch.float16,
+ quant_config=None,
+ prefix="fc",
+ )
+
+ assert fc_no_quant.quant_config is None
+ assert fc_no_quant.input_size == input_size
+ assert fc_no_quant.output_size == output_size
+
+ mock_quant_config = Mock()
+ fc_with_quant = ReplicatedLinear(
+ input_size=input_size,
+ output_size=output_size,
+ bias=False,
+ params_dtype=torch.float16,
+ quant_config=mock_quant_config,
+ prefix="fc",
+ )
+
+ assert fc_with_quant.quant_config == mock_quant_config
+
+ # Check forward pass
+ x = torch.randn(2, input_size, dtype=torch.float16)
+ output, _ = fc_no_quant(x)
+ assert output.shape == (2, output_size)
+
+
+def test_kv_cache_scale_name_handling():
+ # Mock a quant config that supports cache scales
+ mock_quant_config = Mock()
+ mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
+
+ # Condition check in load_weights
+ name = "layers.0.self_attn.k_proj.weight"
+ scale_name = mock_quant_config.get_cache_scale(name)
+
+ # Check if get_cache_scale is called and returns expected value
+ mock_quant_config.get_cache_scale.assert_called_once_with(name)
+ assert scale_name == "layers.0.self_attn.kv_scale"
+
+
+def test_kv_cache_scale_name_no_scale():
+ # Mock a quant config that returns None for get_cache_scale
+ mock_quant_config = Mock()
+ mock_quant_config.get_cache_scale = Mock(return_value=None)
+
+ name = "layers.0.mlp.gate_proj.weight"
+ scale_name = mock_quant_config.get_cache_scale(name)
+
+ # Should return None for weights that don't have cache scales
+ assert scale_name is None
+
+
+def test_maybe_remap_kv_scale_name():
+ from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
+
+ params_dict = {
+ "layers.0.self_attn.kv_scale": Mock(),
+ "layers.1.self_attn.kv_scale": Mock(),
+ }
+
+ name = "layers.0.self_attn.some_scale"
+ remapped = maybe_remap_kv_scale_name(name, params_dict)
+
+ assert remapped in params_dict or remapped == name or remapped is None
+
+
+def test_load_weights_kv_scale_handling():
+ kv_scale_param = Mock()
+ kv_scale_param.weight_loader = Mock()
+
+ params_dict = {
+ "layers.0.self_attn.kv_scale": kv_scale_param,
+ }
+
+ mock_quant_config = Mock()
+ mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
+
+    # Mimic the load_weights logic for KV cache scales
+ name = "layers.0.self_attn.k_proj.weight"
+ loaded_weight_tensor = torch.tensor([1.0, 2.0])
+
+ if mock_quant_config is not None:
+ scale_name = mock_quant_config.get_cache_scale(name)
+ if scale_name:
+ param = params_dict[scale_name]
+ assert param is kv_scale_param
+ weight_to_load = (
+ loaded_weight_tensor
+ if loaded_weight_tensor.dim() == 0
+ else loaded_weight_tensor[0]
+ )
+
+ assert scale_name == "layers.0.self_attn.kv_scale"
+ assert weight_to_load == loaded_weight_tensor[0]
diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py
index f8e3fa7d1560f..0d41b93233d5a 100644
--- a/tests/models/language/pooling/test_extract_hidden_states.py
+++ b/tests/models/language/pooling/test_extract_hidden_states.py
@@ -11,7 +11,7 @@ from vllm import TokensPrompt
["Qwen/Qwen3-0.6B"],
)
@torch.inference_mode
-def test_embed_models(hf_runner, vllm_runner, model: str):
+def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
n_prompt_tokens = [55, 56, 57]
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
@@ -21,7 +21,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
enforce_eager=True,
runner="pooling",
enable_chunked_prefill=False,
- enable_prefix_caching=False,
+ enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
@@ -30,4 +30,29 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
for n, output in zip(n_prompt_tokens, pooling_outputs):
assert len(output.prompt_token_ids) == n
+ assert len(output.outputs.data) == n
assert output.num_cached_tokens == 0
+
+    # Test enable_prefix_caching combined with all-token pooling:
+    # this request must skip reading from the prefix cache, via
+    # request.skip_reading_prefix_cache
+ pooling_outputs = vllm_model.llm.encode(
+ [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+ pooling_task="token_embed",
+ )
+
+ for n, output in zip(n_prompt_tokens, pooling_outputs):
+ assert len(output.prompt_token_ids) == n
+ assert len(output.outputs.data) == n
+ assert output.num_cached_tokens == 0
+
+    # skip_reading_prefix_cache still allows writing to the cache,
+    # which accelerates subsequent requests
+ pooling_outputs = vllm_model.llm.encode(
+ [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+ pooling_task="embed",
+ )
+
+ for n, output in zip(n_prompt_tokens, pooling_outputs):
+ assert len(output.prompt_token_ids) == n
+ assert output.num_cached_tokens > 0
diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py
index 2482452645ef1..a31a771238e26 100644
--- a/tests/models/language/pooling/test_mm_classifier_conversion.py
+++ b/tests/models/language/pooling/test_mm_classifier_conversion.py
@@ -75,7 +75,7 @@ def test_gemma_multimodal(
{
"type": "image_url",
"image_url": {
- "url": "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
+ "url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/red_chair.jpg"
},
},
{"type": "text", "text": "A fine 19th century piece of furniture."},
diff --git a/tests/models/multimodal/generation/test_multimodal_gguf.py b/tests/models/multimodal/generation/test_multimodal_gguf.py
new file mode 100644
index 0000000000000..e596b20c6302b
--- /dev/null
+++ b/tests/models/multimodal/generation/test_multimodal_gguf.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Literal, NamedTuple
+
+import pytest
+from huggingface_hub import hf_hub_download
+from pytest import MarkDecorator
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm.assets.image import ImageAsset
+from vllm.utils.torch_utils import set_default_torch_num_threads
+
+from ....conftest import PromptImageInput, VllmRunner
+from ...utils import check_logprobs_close
+
+
+class GGUFMMTestConfig(NamedTuple):
+ original_model: str
+ gguf_repo: str
+ gguf_backbone: str
+ gguf_mmproj: str
+ prompt: list[str]
+ mm_data: dict[Literal["images"], PromptImageInput]
+ max_model_len: int = 4096
+ marks: list[MarkDecorator] = []
+
+ @property
+ def gguf_model(self):
+ hf_hub_download(self.gguf_repo, filename=self.gguf_mmproj)
+ return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
+
+
+GEMMA3_CONFIG = GGUFMMTestConfig(
+ original_model="google/gemma-3-4b-it",
+ gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
+ gguf_backbone="gemma-3-4b-it-q4_0.gguf",
+ gguf_mmproj="mmproj-model-f16-4B.gguf",
+ prompt=["Describe this image in detail:"],
+ mm_data={"images": [ImageAsset("stop_sign").pil_image]},
+ marks=[pytest.mark.core_model],
+)
+
+MODELS_TO_TEST = [GEMMA3_CONFIG]
+
+
+def run_multimodal_gguf_test(
+ vllm_runner: type[VllmRunner],
+ model: GGUFMMTestConfig,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+):
+ # Run gguf model.
+ with (
+ set_default_torch_num_threads(1),
+ vllm_runner(
+ model_name=model.gguf_model,
+ enforce_eager=True,
+ tokenizer_name=model.original_model,
+ dtype=dtype,
+ max_model_len=model.max_model_len,
+ ) as gguf_model,
+ ):
+ gguf_outputs = gguf_model.generate_greedy_logprobs(
+ prompts=model.prompt,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ **model.mm_data,
+ )
+
+ # Run unquantized model.
+ with vllm_runner(
+ model_name=model.original_model,
+ enforce_eager=True, # faster tests
+ dtype=dtype,
+ max_model_len=model.max_model_len,
+ ) as original_model:
+ original_outputs = original_model.generate_greedy_logprobs(
+ prompts=model.prompt,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ **model.mm_data,
+ )
+
+ check_logprobs_close(
+ outputs_0_lst=original_outputs,
+ outputs_1_lst=gguf_outputs,
+ name_0="original",
+ name_1="gguf",
+ )
+
+
+@pytest.mark.skipif(
+ not is_quant_method_supported("gguf"),
+ reason="gguf is not supported on this GPU type.",
+)
+@pytest.mark.parametrize(
+ "model",
+ [
+ pytest.param(test_config, marks=test_config.marks)
+ for test_config in MODELS_TO_TEST
+ ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_models(
+ vllm_runner: type[VllmRunner],
+ model: GGUFMMTestConfig,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+) -> None:
+ run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
diff --git a/tests/models/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py
index 24220978534ca..dc4b4546e451b 100644
--- a/tests/models/quantization/test_bitsandbytes.py
+++ b/tests/models/quantization/test_bitsandbytes.py
@@ -14,10 +14,13 @@ from vllm.platforms import current_platform
from ...utils import compare_two_settings, multi_gpu_test
from ..utils import check_embeddings_close, check_logprobs_close
-pytestmark = pytest.mark.skipif(
- current_platform.is_rocm(),
- reason="bitsandbytes quantization not supported on ROCm (CUDA-only kernels)",
-)
+if current_platform.is_rocm():
+ from vllm.platforms.rocm import on_gfx9
+
+ pytestmark = pytest.mark.skipif(
+ on_gfx9(),
+ reason="bitsandbytes not supported on gfx9 (warp size 64 limitation)",
+ )
models_4bit_to_test = [
("facebook/opt-125m", "quantize opt model inflight"),
diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py
index 5e2438857aeef..3b9597507ac1b 100644
--- a/tests/models/quantization/test_gguf.py
+++ b/tests/models/quantization/test_gguf.py
@@ -78,6 +78,12 @@ DOLPHIN_CONFIG = GGUFTestConfig(
gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
)
+GEMMA3_CONFIG = GGUFTestConfig(
+ original_model="google/gemma-3-270m-it",
+ gguf_repo="ggml-org/gemma-3-270m-it-qat-GGUF",
+ gguf_filename="gemma-3-270m-it-qat-Q4_0.gguf",
+)
+
MODELS = [
# LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458
QWEN2_CONFIG,
@@ -85,6 +91,7 @@ MODELS = [
GPT2_CONFIG,
STABLELM_CONFIG,
DOLPHIN_CONFIG,
+ GEMMA3_CONFIG,
# STARCODER_CONFIG, # broken
]
@@ -148,7 +155,7 @@ def check_model_outputs(
"model",
[pytest.param(test_config, marks=test_config.marks) for test_config in MODELS],
)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index b8b9dc9c43799..b33f3ab2b5a11 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -173,6 +173,10 @@ class _HfExamplesInfo:
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
+ "AfmoeForCausalLM": _HfExamplesInfo(
+ "arcee-ai/Trinity-Nano",
+ is_available_online=False,
+ ),
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py
index ea795fcbbde55..639e290406fe2 100644
--- a/tests/multimodal/test_utils.py
+++ b/tests/multimodal/test_utils.py
@@ -16,10 +16,10 @@ from vllm.multimodal.utils import MediaConnector, argsort_mm_positions
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_ASSETS = [
- "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
- "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
- "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
- "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+ "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ "Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
+ "1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
+ "RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
]
TEST_VIDEO_URLS = [
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index e7d902ed26aaa..31b65189b5ec3 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -141,7 +141,7 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
],
)
-@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.parametrize(
"use_aiter", [True, False] if current_platform.is_rocm() else [False]
@@ -182,7 +182,7 @@ def test_compressed_tensors_w8a8_logprobs(
example_prompts, max_tokens, num_logprobs
)
- with vllm_runner(model_path, dtype=dtype) as vllm_model:
+ with vllm_runner(model_path, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs
)
diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py
index a3fb4a6953474..1591ce1c4f5ad 100644
--- a/tests/quantization/test_cpu_offload.py
+++ b/tests/quantization/test_cpu_offload.py
@@ -19,8 +19,8 @@ def test_cpu_offload_fp8():
# Test loading a quantized checkpoint
compare_two_settings(
"neuralmagic/Qwen2-1.5B-Instruct-FP8",
- [],
- ["--cpu-offload-gb", "1"],
+ ["--enforce_eager"],
+ ["--enforce_eager", "--cpu-offload-gb", "1"],
max_wait_seconds=480,
)
@@ -35,8 +35,8 @@ def test_cpu_offload_gptq(monkeypatch):
# Test GPTQ Marlin
compare_two_settings(
"Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
- [],
- ["--cpu-offload-gb", "1"],
+ ["--enforce_eager"],
+ ["--enforce_eager", "--cpu-offload-gb", "1"],
max_wait_seconds=480,
)
@@ -51,8 +51,8 @@ def test_cpu_offload_awq(monkeypatch):
# Test AWQ Marlin
compare_two_settings(
"Qwen/Qwen2-1.5B-Instruct-AWQ",
- [],
- ["--cpu-offload-gb", "1"],
+ ["--enforce_eager"],
+ ["--enforce_eager", "--cpu-offload-gb", "1"],
max_wait_seconds=480,
)
@@ -67,7 +67,7 @@ def test_cpu_offload_compressed_tensors(monkeypatch):
# Test wNa16
compare_two_settings(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2",
- [],
- ["--cpu-offload-gb", "1"],
+ ["--enforce_eager"],
+ ["--enforce_eager", "--cpu-offload-gb", "1"],
max_wait_seconds=480,
)
diff --git a/tests/quantization/test_cpu_wna16.py b/tests/quantization/test_cpu_wna16.py
new file mode 100644
index 0000000000000..077b802e559dc
--- /dev/null
+++ b/tests/quantization/test_cpu_wna16.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.platforms import current_platform
+
+if not current_platform.is_cpu():
+ pytest.skip("skipping CPU-only tests", allow_module_level=True)
+
+MODELS = [
+ "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ",
+ "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", # with g_idx
+]
+DTYPE = ["bfloat16"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", DTYPE)
+def test_cpu_wna16(vllm_runner, model, dtype):
+ with vllm_runner(model, dtype=dtype) as llm:
+ output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
+ assert output
+ print(output)
diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py
index 2a72f734e431b..b992e976ac308 100644
--- a/tests/quantization/test_experts_int8.py
+++ b/tests/quantization/test_experts_int8.py
@@ -21,7 +21,7 @@ MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"]
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [10])
+@pytest.mark.parametrize("max_tokens", [4])
def test_model_experts_int8_startup(
hf_runner,
vllm_runner,
@@ -33,5 +33,7 @@ def test_model_experts_int8_startup(
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_transformers_version(on_fail="skip")
- with vllm_runner(model, dtype=dtype, quantization="experts_int8") as vllm_model:
+ with vllm_runner(
+ model, dtype=dtype, enforce_eager=True, quantization="experts_int8"
+ ) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index f02da2996ffea..7bcac9ad768e7 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -45,10 +45,10 @@ def test_model_load_and_run(
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
- with vllm_runner(model_id) as llm:
+ with vllm_runner(model_id, enforce_eager=True) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
- outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
+ outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
print(outputs[0][1])
@@ -85,7 +85,7 @@ def test_kv_cache_model_load_and_run(
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
- with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
+ with vllm_runner(model_id, kv_cache_dtype="fp8", enforce_eager=True) as llm:
def check_model(model):
attn = model.model.layers[0].self_attn.attn
@@ -112,7 +112,7 @@ def test_kv_cache_model_load_and_run(
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
- outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
+ outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
print(outputs[0][1])
@@ -142,7 +142,10 @@ def test_load_fp16_model(
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
with vllm_runner(
- "facebook/opt-125m", quantization="fp8", kv_cache_dtype=kv_cache_dtype
+ "facebook/opt-125m",
+ quantization="fp8",
+ enforce_eager=True,
+ kv_cache_dtype=kv_cache_dtype,
) as llm:
def check_model(model):
diff --git a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py
index ae9b1df3377dc..4f3c52df6c283 100644
--- a/tests/quantization/test_ipex_quant.py
+++ b/tests/quantization/test_ipex_quant.py
@@ -26,7 +26,7 @@ DTYPE = ["bfloat16"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_ipex_quant(vllm_runner, model, dtype):
- with vllm_runner(model, dtype=dtype) as llm:
- output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
+ with vllm_runner(model, dtype=dtype, enforce_eager=True) as llm:
+ output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
assert output
print(output)
diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py
index f009a4cfb870d..d92dfaa2cc7b5 100644
--- a/tests/quantization/test_lm_head.py
+++ b/tests/quantization/test_lm_head.py
@@ -49,4 +49,4 @@ def test_lm_head(
vllm_model.apply_model(check_model)
- print(vllm_model.generate_greedy(["Hello my name is"], max_tokens=10)[0][1])
+ print(vllm_model.generate_greedy(["Hello my name is"], max_tokens=4)[0][1])
diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py
index 8abf65d29784d..0298994c396f6 100644
--- a/tests/quantization/test_modelopt.py
+++ b/tests/quantization/test_modelopt.py
@@ -88,6 +88,6 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
llm.apply_model(check_model)
# Run a simple generation test to ensure the model works
- output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
+ output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
assert output
print(f"ModelOpt FP8 output: {output}")
diff --git a/tests/quantization/test_ptpc_fp8.py b/tests/quantization/test_ptpc_fp8.py
index e8ea4148585bf..61efd2ce66c71 100644
--- a/tests/quantization/test_ptpc_fp8.py
+++ b/tests/quantization/test_ptpc_fp8.py
@@ -38,6 +38,7 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
"facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
+ enforce_eager=True,
kv_cache_dtype=kv_cache_dtype,
)
except AssertionError as e:
@@ -65,5 +66,5 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
llm.apply_model(check_model)
- output = llm.generate_greedy("Hello my name is", max_tokens=20)
+ output = llm.generate_greedy("Hello my name is", max_tokens=4)
assert output
diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py
index 8da048703df93..a09856c78559a 100644
--- a/tests/quantization/test_register_quantization_config.py
+++ b/tests/quantization/test_register_quantization_config.py
@@ -23,8 +23,8 @@ from vllm.model_executor.layers.quantization import (
get_quantization_config,
register_quantization_config,
)
-from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
- QuantizationConfig,
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig, # noqa: E501
)
@@ -142,5 +142,5 @@ def test_custom_quant(vllm_runner, model, monkeypatch):
llm.apply_model(check_model)
- output = llm.generate_greedy("Hello my name is", max_tokens=20)
+ output = llm.generate_greedy("Hello my name is", max_tokens=1)
assert output
diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index 82413f36e997f..fb8d6130c3779 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -392,7 +392,7 @@ def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
assert not has_int4_preshuffled_tensor
assert weight_attrs == [False, 1, 0, True]
- output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
+ output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
assert output
diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index c9d227599cde5..ea40c48027205 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -24,9 +24,7 @@ def test_ranks(
greedy,
flat_logprobs,
example_prompts,
- monkeypatch: pytest.MonkeyPatch,
):
- monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
tokenizer = vllm_model.llm.get_tokenizer()
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -36,6 +34,7 @@ def test_ranks(
max_tokens=MAX_TOKENS,
logprobs=NUM_TOP_LOGPROBS,
prompt_logprobs=NUM_PROMPT_LOGPROBS,
+ flat_logprobs=flat_logprobs,
)
results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
diff --git a/tests/test_inputs.py b/tests/test_inputs.py
index 50a273016ab80..b1fb4e06a6906 100644
--- a/tests/test_inputs.py
+++ b/tests/test_inputs.py
@@ -86,34 +86,6 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
assert zipped["mm_processor_kwargs"] == exp_kwargs
-@pytest.mark.parametrize(
- "model_id",
- [
- "facebook/opt-125m",
- ],
-)
-@pytest.mark.parametrize(
- "prompt",
- [
- {
- "prompt": "",
- "multi_modal_data": {"dummy": []},
- },
- {
- "prompt_token_ids": [],
- "multi_modal_data": {"dummy": []},
- },
- ],
-)
-def test_preprocessor_text_no_mm_inputs(model_id, prompt):
- model_config = ModelConfig(model=model_id)
- tokenizer = init_tokenizer_from_configs(model_config)
- input_preprocessor = InputPreprocessor(model_config, tokenizer)
-
- with pytest.raises(ValueError, match="does not support multimodal inputs"):
- input_preprocessor.preprocess(prompt)
-
-
@pytest.mark.parametrize(
"model_id",
[
@@ -127,6 +99,13 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
{"prompt_token_ids": []},
],
)
+@pytest.mark.skip(
+ reason=(
+ "Applying huggingface processor on text inputs results in "
+ "significant performance regression for multimodal models. "
+ "See https://github.com/vllm-project/vllm/issues/26320"
+ )
+)
def test_preprocessor_always_mm_code_path(model_id, prompt):
model_config = ModelConfig(model=model_id)
tokenizer = init_tokenizer_from_configs(model_config)
diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py
index d26a460d2bcab..75e9d337aa24e 100644
--- a/tests/test_logprobs.py
+++ b/tests/test_logprobs.py
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import pytest
-
from vllm.logprobs import (
FlatLogprobs,
Logprob,
@@ -14,24 +12,20 @@ from vllm.logprobs import (
)
-def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
- monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-
- prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_non_flat() -> None:
+ prompt_logprobs = create_prompt_logprobs(flat_logprobs=False)
assert isinstance(prompt_logprobs, list)
# Ensure first prompt position logprobs is None
assert len(prompt_logprobs) == 1
assert prompt_logprobs[0] is None
- sample_logprobs = create_sample_logprobs()
+ sample_logprobs = create_sample_logprobs(flat_logprobs=False)
assert isinstance(sample_logprobs, list)
assert len(sample_logprobs) == 0
-def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
- monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-
- prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_flat() -> None:
+ prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
assert isinstance(prompt_logprobs, FlatLogprobs)
assert prompt_logprobs.start_indices == [0]
assert prompt_logprobs.end_indices == [0]
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(prompt_logprobs) == 1
assert prompt_logprobs[0] == dict()
- sample_logprobs = create_sample_logprobs()
+ sample_logprobs = create_sample_logprobs(flat_logprobs=True)
assert isinstance(sample_logprobs, FlatLogprobs)
assert len(sample_logprobs.start_indices) == 0
assert len(sample_logprobs.end_indices) == 0
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(sample_logprobs) == 0
-def test_append_logprobs_for_next_position_none_flat(
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
- logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_none_flat() -> None:
+ logprobs = create_sample_logprobs(flat_logprobs=False)
append_logprobs_for_next_position(
logprobs,
token_ids=[1],
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
]
-def test_append_logprobs_for_next_position_flat(
- monkeypatch: pytest.MonkeyPatch,
-) -> None:
- monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
- logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_flat() -> None:
+ logprobs = create_sample_logprobs(flat_logprobs=True)
append_logprobs_for_next_position(
logprobs,
token_ids=[1],
diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py
index c358589dbc292..3a48b5206141d 100644
--- a/tests/tool_use/test_kimi_k2_tool_parser.py
+++ b/tests/tool_use/test_kimi_k2_tool_parser.py
@@ -60,6 +60,11 @@ def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
ids=[
"tool_call_with_content_before",
"multi_tool_call_with_content_before",
+ "concatenated_tool_calls_bug_fix",
+ "three_concatenated_tool_calls",
+ "mixed_spacing_tool_calls",
+ "angle_brackets_in_json",
+ "newlines_in_json",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
@@ -114,6 +119,123 @@ functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool
],
"I'll help you check the weather. ",
),
+ (
+ """I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"latitude": 34.0522, "longitude": -118.2437}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"content": "Los Angeles today"}<|tool_call_end|><|tool_calls_section_end|>""",
+ [
+ ToolCall(
+ id="functions.get_weather:0",
+ function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps(
+ {"latitude": 34.0522, "longitude": -118.2437}
+ ),
+ ),
+ type="function",
+ ),
+ ToolCall(
+ id="functions.get_news:1",
+ function=FunctionCall(
+ name="get_news",
+ arguments=json.dumps({"content": "Los Angeles today"}),
+ ),
+ type="function",
+ ),
+ ],
+ "I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. ",
+ ),
+ (
+ """I'll help you with multiple tasks. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "New York"}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"topic": "technology"}<|tool_call_end|><|tool_call_begin|>functions.send_email:2<|tool_call_argument_begin|>{"to": "user@example.com", "subject": "Daily Update"}<|tool_call_end|><|tool_calls_section_end|>""",
+ [
+ ToolCall(
+ id="functions.get_weather:0",
+ function=FunctionCall(
+ name="get_weather",
+ arguments=json.dumps({"city": "New York"}),
+ ),
+ type="function",
+ ),
+ ToolCall(
+ id="functions.get_news:1",
+ function=FunctionCall(
+ name="get_news",
+ arguments=json.dumps({"topic": "technology"}),
+ ),
+ type="function",
+ ),
+ ToolCall(
+ id="functions.send_email:2",
+ function=FunctionCall(
+ name="send_email",
+ arguments=json.dumps(
+ {"to": "user@example.com", "subject": "Daily Update"}
+ ),
+ ),
+ type="function",
+ ),
+ ],
+ "I'll help you with multiple tasks. ",
+ ),
+ (
+ """Mixed spacing test. <|tool_calls_section_begin|> <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {} <|tool_call_end|><|tool_call_begin|>functions.test2:1<|tool_call_argument_begin|>{}<|tool_call_end|> <|tool_calls_section_end|>""",
+ [
+ ToolCall(
+ id="functions.test:0",
+ function=FunctionCall(
+ name="test",
+ arguments=json.dumps({}),
+ ),
+ type="function",
+ ),
+ ToolCall(
+ id="functions.test2:1",
+ function=FunctionCall(
+ name="test2",
+ arguments=json.dumps({}),
+ ),
+ type="function",
+ ),
+ ],
+ "Mixed spacing test. ",
+ ),
+ (
+        """I need to process HTML content. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_html:0<|tool_call_argument_begin|>{"html": "content<br>", "text": "normal text"}<|tool_call_end|><|tool_calls_section_end|>""",
+ [
+ ToolCall(
+ id="functions.process_html:0",
+ function=FunctionCall(
+ name="process_html",
+ arguments=json.dumps(
+                        {"html": "content<br>", "text": "normal text"}
+ ),
+ ),
+ type="function",
+ )
+ ],
+ "I need to process HTML content. ",
+ ),
+ (
+ """I need to process formatted JSON. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_data:0<|tool_call_argument_begin|>{
+ "name": "test",
+ "value": 123,
+ "nested": {
+ "key": "value"
+ }
+}<|tool_call_end|><|tool_calls_section_end|>""",
+ [
+ ToolCall(
+ id="functions.process_data:0",
+ function=FunctionCall(
+ name="process_data",
+ arguments=json.dumps(
+ {"name": "test", "value": 123, "nested": {"key": "value"}},
+ indent=2,
+ ),
+ ),
+ type="function",
+ )
+ ],
+ "I need to process formatted JSON. ",
+ ),
],
)
def test_extract_tool_calls(
@@ -209,3 +331,596 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser):
assert result is not None
assert hasattr(result, "content")
assert result.content == " without any tool calls."
+
+
+def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
+ """
+ Test that text between <|tool_calls_section_begin|> and <|tool_call_begin|>
+ is suppressed and does not leak into reasoning_delta.
+ This is the main vulnerability being fixed.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ # Get token IDs for the markers
+ section_begin_token_id = kimi_k2_tool_parser.vocab.get(
+ "<|tool_calls_section_begin|>"
+ )
+ tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
+
+ # Simulate streaming sequence:
+ # Delta 1: "I'll help you with that. "
+ result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text="I'll help you with that. ",
+ delta_text="I'll help you with that. ",
+ previous_token_ids=[],
+ current_token_ids=[1, 2, 3], # Regular tokens
+ delta_token_ids=[1, 2, 3],
+ request=None,
+ )
+ assert result1 is not None
+ assert result1.content == "I'll help you with that. "
+
+ # Delta 2: "<|tool_calls_section_begin|>"
+ prev_ids = [1, 2, 3]
+ curr_ids = prev_ids + [section_begin_token_id]
+ result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="I'll help you with that. ",
+ current_text="I'll help you with that. <|tool_calls_section_begin|>",
+ delta_text="<|tool_calls_section_begin|>",
+ previous_token_ids=prev_ids,
+ current_token_ids=curr_ids,
+ delta_token_ids=[section_begin_token_id],
+ request=None,
+ )
+ # Section marker should be stripped and suppressed
+ assert result2 is None or (result2.content is None or result2.content == "")
+
+ # Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
+ prev_ids = curr_ids
+ curr_ids = curr_ids + [4, 5]
+ result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="I'll help you with that. <|tool_calls_section_begin|>",
+ current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
+ delta_text=" spurious text ",
+ previous_token_ids=prev_ids,
+ current_token_ids=curr_ids,
+ delta_token_ids=[4, 5],
+ request=None,
+ )
+ # CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
+ assert result3 is None or (result3.content is None or result3.content == "")
+
+ # Delta 4: "<|tool_call_begin|>..."
+ prev_ids = curr_ids
+ curr_ids = curr_ids + [tool_call_begin_token_id]
+ _result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
+ current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
+ delta_text="<|tool_call_begin|>",
+ previous_token_ids=prev_ids,
+ current_token_ids=curr_ids,
+ delta_token_ids=[tool_call_begin_token_id],
+ request=None,
+ )
+ # Now we're in tool call mode, result depends on internal state
+ # The key is that the spurious text from Delta 3 was not leaked
+
+
+def test_split_markers_across_deltas(kimi_k2_tool_parser):
+ """
+ Test that markers split across delta chunks are correctly detected
+ via the rolling buffer mechanism.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_token_id = kimi_k2_tool_parser.vocab.get(
+ "<|tool_calls_section_begin|>"
+ )
+
+ # Delta 1: "...reasoning<|tool_calls_sec"
+ _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Some reasoning",
+ current_text="Some reasoning<|tool_calls_sec",
+ delta_text="<|tool_calls_sec",
+ previous_token_ids=[1, 2],
+ current_token_ids=[1, 2, 3], # Partial token
+ delta_token_ids=[3],
+ request=None,
+ )
+ # Partial token not recognized yet, might be buffered
+ # Should return as content or None (depends on implementation)
+
+ # Delta 2: "tion_begin|> " (completes the marker)
+ _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Some reasoning<|tool_calls_sec",
+ current_text="Some reasoning<|tool_calls_section_begin|> ",
+ delta_text="tion_begin|> ",
+ previous_token_ids=[1, 2, 3],
+ current_token_ids=[1, 2, section_begin_token_id, 4],
+ delta_token_ids=[section_begin_token_id, 4],
+ request=None,
+ )
+ # Now the complete marker should be detected via buffer
+ # The parser should enter tool section mode
+ assert kimi_k2_tool_parser.in_tool_section is True
+
+
+def test_marker_variants(kimi_k2_tool_parser):
+ """Test that both singular and plural marker variants are recognized."""
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ # Test singular variant: <|tool_call_section_begin|> (note: singular "call")
+ singular_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_section_begin|>")
+
+ if singular_token_id is not None: # Only test if tokenizer supports it
+ _result = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Reasoning ",
+ current_text="Reasoning <|tool_call_section_begin|>",
+ delta_text="<|tool_call_section_begin|>",
+ previous_token_ids=[1, 2],
+ current_token_ids=[1, 2, singular_token_id],
+ delta_token_ids=[singular_token_id],
+ request=None,
+ )
+ # Should enter tool section mode with singular variant too
+ assert kimi_k2_tool_parser.in_tool_section is True
+
+
+def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
+ """
+ Test that after exiting a tool section with <|tool_calls_section_end|>,
+ subsequent text is correctly returned as reasoning content.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+ section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
+
+ # Enter tool section
+ _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text="<|tool_calls_section_begin|>",
+ delta_text="<|tool_calls_section_begin|>",
+ previous_token_ids=[],
+ current_token_ids=[section_begin_id],
+ delta_token_ids=[section_begin_id],
+ request=None,
+ )
+ assert kimi_k2_tool_parser.in_tool_section is True
+
+ # Exit tool section
+ _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="<|tool_calls_section_begin|>",
+ current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
+ delta_text="<|tool_calls_section_end|>",
+ previous_token_ids=[section_begin_id],
+ current_token_ids=[section_begin_id, section_end_id],
+ delta_token_ids=[section_end_id],
+ request=None,
+ )
+ assert kimi_k2_tool_parser.in_tool_section is False
+
+ # Subsequent reasoning text should be returned normally
+ result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
+ current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
+ delta_text=" More reasoning",
+ previous_token_ids=[section_begin_id, section_end_id],
+ current_token_ids=[section_begin_id, section_end_id, 10, 11],
+ delta_token_ids=[10, 11],
+ request=None,
+ )
+ assert result3 is not None
+ assert result3.content == " More reasoning"
+
+
+def test_empty_tool_section(kimi_k2_tool_parser):
+ """Test an empty tool section (begin immediately followed by end)."""
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+ section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
+
+ # Section begin
+ _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Reasoning ",
+ current_text="Reasoning <|tool_calls_section_begin|>",
+ delta_text="<|tool_calls_section_begin|>",
+ previous_token_ids=[1],
+ current_token_ids=[1, section_begin_id],
+ delta_token_ids=[section_begin_id],
+ request=None,
+ )
+
+ # Immediate section end
+ _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Reasoning <|tool_calls_section_begin|>",
+ current_text="Reasoning <|tool_calls_section_begin|><|tool_calls_section_end|>",
+ delta_text="<|tool_calls_section_end|>",
+ previous_token_ids=[1, section_begin_id],
+ current_token_ids=[1, section_begin_id, section_end_id],
+ delta_token_ids=[section_end_id],
+ request=None,
+ )
+ # Should exit cleanly without errors
+ assert kimi_k2_tool_parser.in_tool_section is False
+
+
+def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
+ """
+ Test that the parser recovers from a malformed tool section
+ that never closes properly.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+
+ # Enter tool section
+ _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text="<|tool_calls_section_begin|>",
+ delta_text="<|tool_calls_section_begin|>",
+ previous_token_ids=[],
+ current_token_ids=[section_begin_id],
+ delta_token_ids=[section_begin_id],
+ request=None,
+ )
+ assert kimi_k2_tool_parser.in_tool_section is True
+
+ # Simulate a lot of text without proper tool calls or section end
+ # This should trigger the error recovery mechanism
+ large_text = "x" * 10000 # Exceeds max_section_chars
+
+ result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="<|tool_calls_section_begin|>",
+ current_text="<|tool_calls_section_begin|>" + large_text,
+ delta_text=large_text,
+ previous_token_ids=[section_begin_id],
+ current_token_ids=[section_begin_id] + list(range(100, 100 + len(large_text))),
+ delta_token_ids=list(range(100, 100 + len(large_text))),
+ request=None,
+ )
+
+ # Parser should have force-exited the tool section
+ assert kimi_k2_tool_parser.in_tool_section is False
+ # And returned the content as reasoning
+ assert result2 is not None
+ assert result2.content == large_text
+
+
+def test_state_reset(kimi_k2_tool_parser):
+ """Test that reset_streaming_state() properly clears all state."""
+ # Put parser in a complex state
+ kimi_k2_tool_parser.in_tool_section = True
+ kimi_k2_tool_parser.token_buffer = "some buffer"
+ kimi_k2_tool_parser.current_tool_id = 5
+ kimi_k2_tool_parser.prev_tool_call_arr = [{"id": "test"}]
+ kimi_k2_tool_parser.section_char_count = 1000
+
+ # Reset
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ # Verify all state is cleared
+ assert kimi_k2_tool_parser.in_tool_section is False
+ assert kimi_k2_tool_parser.token_buffer == ""
+ assert kimi_k2_tool_parser.current_tool_id == -1
+ assert kimi_k2_tool_parser.prev_tool_call_arr == []
+ assert kimi_k2_tool_parser.section_char_count == 0
+ assert kimi_k2_tool_parser.current_tool_name_sent is False
+ assert kimi_k2_tool_parser.streamed_args_for_tool == []
+
+
+def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser):
+ """
+ Test that begin→noise→tool_begin within the SAME chunk suppresses
+ the noise text correctly (not just across chunks).
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+ tool_call_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
+
+ # Single delta containing: section_begin + spurious text + tool_call_begin
+ combined_text = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>"
+
+ result = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Reasoning ",
+ current_text="Reasoning " + combined_text,
+ delta_text=combined_text,
+ previous_token_ids=[1, 2],
+ current_token_ids=[1, 2, section_begin_id, 3, 4, tool_call_begin_id],
+ delta_token_ids=[section_begin_id, 3, 4, tool_call_begin_id],
+ request=None,
+ )
+
+ # The noise text should NOT leak into content
+ # Result should either be None/empty or start tool call parsing
+ if result is not None and result.content is not None:
+ # If content is returned, it should not contain the noise
+ assert "noise text" not in result.content
+ assert result.content == "" or result.content.strip() == ""
+
+
+def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser):
+ """
+ Test that if the stream ends (EOF) without a proper section end marker,
+ the parser doesn't leak text, doesn't crash, and resets state cleanly.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+
+ # Enter tool section
+ _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text="<|tool_calls_section_begin|>",
+ delta_text="<|tool_calls_section_begin|>",
+ previous_token_ids=[],
+ current_token_ids=[section_begin_id],
+ delta_token_ids=[section_begin_id],
+ request=None,
+ )
+ assert kimi_k2_tool_parser.in_tool_section is True
+
+ # Some content in tool section
+ result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="<|tool_calls_section_begin|>",
+ current_text="<|tool_calls_section_begin|> partial content",
+ delta_text=" partial content",
+ previous_token_ids=[section_begin_id],
+ current_token_ids=[section_begin_id, 10, 11],
+ delta_token_ids=[10, 11],
+ request=None,
+ )
+ # Content should be suppressed
+ assert result2.content == "" or result2.content is None
+
+ # Stream ends (EOF) - no more deltas, no section_end marker
+ # Simulate this by manually checking state and resetting
+ # (In real usage, the request handler would call reset_streaming_state)
+ assert kimi_k2_tool_parser.in_tool_section is True # Still in section
+
+ # Reset state (as would happen between requests)
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ # Verify clean slate
+ assert kimi_k2_tool_parser.in_tool_section is False
+ assert kimi_k2_tool_parser.token_buffer == ""
+
+ # Next request should work normally
+ result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text="New reasoning",
+ delta_text="New reasoning",
+ previous_token_ids=[],
+ current_token_ids=[20, 21],
+ delta_token_ids=[20, 21],
+ request=None,
+ )
+ assert result3 is not None
+ assert result3.content == "New reasoning"
+
+
+def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser):
+ """
+ CRITICAL TEST: Verify that when both section_begin and section_end
+ markers appear in the SAME chunk, the parser correctly:
+ 1. Enters the tool section
+ 2. Immediately exits the tool section
+ 3. Does NOT get stuck in in_tool_section=True state
+
+ This tests the bug fix where elif was changed to if to handle
+ both state transitions in a single delta.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+ section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
+
+ # Single chunk with both markers (e.g., empty tool section)
+ combined_delta = "<|tool_calls_section_begin|><|tool_calls_section_end|>"
+
+ result = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Some reasoning ",
+ current_text="Some reasoning " + combined_delta,
+ delta_text=combined_delta,
+ previous_token_ids=[1, 2],
+ current_token_ids=[1, 2, section_begin_id, section_end_id],
+ delta_token_ids=[section_begin_id, section_end_id],
+ request=None,
+ )
+
+ # CRITICAL: Parser should NOT be stuck in tool section
+ assert kimi_k2_tool_parser.in_tool_section is False, (
+ "Parser stuck in tool section after processing both begin/end in same chunk. "
+ "This indicates the elif bug was not fixed."
+ )
+
+ # Result should be empty or contain only stripped content
+ assert result is not None
+ assert result.content == "" or result.content is None
+
+ # Verify subsequent content streams correctly (not suppressed)
+ result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Some reasoning " + combined_delta,
+ current_text="Some reasoning " + combined_delta + " More reasoning",
+ delta_text=" More reasoning",
+ previous_token_ids=[1, 2, section_begin_id, section_end_id],
+ current_token_ids=[1, 2, section_begin_id, section_end_id, 10, 11],
+ delta_token_ids=[10, 11],
+ request=None,
+ )
+
+ # This content should NOT be suppressed (we're out of tool section)
+ assert result2 is not None
+ assert result2.content == " More reasoning"
+
+
+def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser):
+ """
+ Test the same-chunk scenario with actual content between markers.
+ Example: <|tool_calls_section_begin|> text <|tool_calls_section_end|>
+ all arriving in one delta. The key is that the state machine correctly
+ transitions in and out within the same chunk.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+ section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
+
+ # Chunk with begin, some whitespace/noise, and end all together
+ # This simulates a tool section that opens and closes in the same chunk
+ combined_delta = "<|tool_calls_section_begin|> <|tool_calls_section_end|>"
+
+ _result = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Reasoning ",
+ current_text="Reasoning " + combined_delta,
+ delta_text=combined_delta,
+ previous_token_ids=[1],
+ current_token_ids=[1, section_begin_id, 100, section_end_id],
+ delta_token_ids=[section_begin_id, 100, section_end_id],
+ request=None,
+ )
+
+ # Parser should exit cleanly (not stuck in tool section)
+ assert kimi_k2_tool_parser.in_tool_section is False
+
+ # Verify the fix: next content should stream normally, not be suppressed
+ result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Reasoning " + combined_delta,
+ current_text="Reasoning " + combined_delta + " Done",
+ delta_text=" Done",
+ previous_token_ids=[1, section_begin_id, 100, section_end_id],
+ current_token_ids=[1, section_begin_id, 100, section_end_id, 200],
+ delta_token_ids=[200],
+ request=None,
+ )
+
+ # Content after section should be returned (not suppressed)
+ assert result2 is not None
+ assert result2.content == " Done"
+
+
+def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
+ """
+ CRITICAL TEST (P1): Verify that when both <|tool_call_end|> and
+ <|tool_calls_section_end|> appear in the SAME chunk, the parser:
+ 1. Processes the tool_call_end first (emits final arguments)
+ 2. THEN exits the section
+ 3. Does NOT drop the final tool call update
+ 4. Does NOT leak special tokens into reasoning
+
+ This tests the deferred section exit fix.
+ """
+ kimi_k2_tool_parser.reset_streaming_state()
+
+ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
+ section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
+ tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
+ tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
+
+ # Simulate a streaming sequence for a SHORT tool call (all in one chunk):
+ # 1. Reasoning text
+ result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text="Let me help. ",
+ delta_text="Let me help. ",
+ previous_token_ids=[],
+ current_token_ids=[1, 2],
+ delta_token_ids=[1, 2],
+ request=None,
+ )
+ assert result1 is not None
+ assert result1.content == "Let me help. "
+
+ # 2. Section begin
+ _result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text="Let me help. ",
+ current_text="Let me help. <|tool_calls_section_begin|>",
+ delta_text="<|tool_calls_section_begin|>",
+ previous_token_ids=[1, 2],
+ current_token_ids=[1, 2, section_begin_id],
+ delta_token_ids=[section_begin_id],
+ request=None,
+ )
+ assert kimi_k2_tool_parser.in_tool_section is True
+
+ # 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK
+ # This is the critical scenario for short tool calls
+ combined = (
+ '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
+ "<|tool_call_end|><|tool_calls_section_end|>"
+ )
+
+ # Build up the previous text gradually to simulate realistic streaming
+ prev_text = "Let me help. <|tool_calls_section_begin|>"
+ curr_text = prev_text + combined
+
+ result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text=prev_text,
+ current_text=curr_text,
+ delta_text=combined,
+ previous_token_ids=[1, 2, section_begin_id],
+ current_token_ids=[
+ 1,
+ 2,
+ section_begin_id,
+ tool_begin_id,
+ 10,
+ 11,
+ 12,
+ tool_end_id,
+ section_end_id,
+ ],
+ delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id],
+ request=None,
+ )
+
+ # CRITICAL: Parser should have exited section AFTER processing tool
+ assert kimi_k2_tool_parser.in_tool_section is False
+
+ # Tool call should have been emitted (not dropped)
+ # The result might be the tool name or None depending on state, but
+ # importantly, it shouldn't be returning the literal tokens as content
+
+ if result3 is not None and result3.content is not None:
+ # Verify no special tokens leaked into content
+ assert "<|tool_call_end|>" not in result3.content
+ assert "<|tool_calls_section_end|>" not in result3.content
+
+ # 4. Verify subsequent content streams normally
+ result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
+ previous_text=curr_text,
+ current_text=curr_text + " Done",
+ delta_text=" Done",
+ previous_token_ids=[
+ 1,
+ 2,
+ section_begin_id,
+ tool_begin_id,
+ 10,
+ 11,
+ 12,
+ tool_end_id,
+ section_end_id,
+ ],
+ current_token_ids=[
+ 1,
+ 2,
+ section_begin_id,
+ tool_begin_id,
+ 10,
+ 11,
+ 12,
+ tool_end_id,
+ section_end_id,
+ 20,
+ ],
+ delta_token_ids=[20],
+ request=None,
+ )
+
+ # Content after tool section should stream normally
+ assert result4 is not None
+ assert result4.content == " Done"
diff --git a/tests/utils.py b/tests/utils.py
index c8f18384c5114..c31a2aeeb9c80 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -676,7 +676,7 @@ def compare_all_settings(
results += _test_image_text(
client,
model,
- "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+ "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
)
elif method == "encode":
results += _test_embeddings(client, model, prompt)
diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py
index b4805be802723..ba0b703302e38 100644
--- a/tests/v1/core/test_priority_scheduler_random.py
+++ b/tests/v1/core/test_priority_scheduler_random.py
@@ -3,6 +3,7 @@
import random
import uuid
+import numpy as np
import pytest
from vllm.config import VllmConfig
@@ -99,8 +100,7 @@ def _mock_execute_model(
random.randint(*num_output_tokens_range) for _ in range(len(request_ids))
]
sampled_token_ids = [
- [random.randint(0, 100) for _ in range(num_tokens)]
- for num_tokens in num_output_tokens
+ np.random.randint(0, 100, size=num_tokens) for num_tokens in num_output_tokens
]
return ModelRunnerOutput(
@@ -196,6 +196,8 @@ def test_priority_scheduling_blast(
num_blocks: int,
):
random.seed(42)
+ np.random.seed(42)
+
seen_request_prompt_length = dict[str, int]()
seen_request_ids = set[str]()
seen_mm_hashes = set[str]()
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 6d95c29ec1ab4..0570c0854c678 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -2636,7 +2636,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
model_output = ModelRunnerOutput(
req_ids=[request1.request_id],
req_id_to_index={request1.request_id: 0},
- sampled_token_ids=[[100]],
+ sampled_token_ids=[np.array([100])],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -2842,7 +2842,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
- sampled_token_ids=[[1000]] * len(req_ids),
+ sampled_token_ids=[np.array([1000])] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -2955,7 +2955,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
- sampled_token_ids=[[100]],
+ sampled_token_ids=[np.array([100])],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -3006,7 +3006,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[[100] for _ in requests],
+ sampled_token_ids=[np.array([100]) for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -3041,7 +3041,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
- sampled_token_ids=[[100], [100, 200]],
+ sampled_token_ids=[np.array([100]), np.array([100, 200])],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -3227,7 +3227,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto
model_output = ModelRunnerOutput(
req_ids=[request1.request_id, request2.request_id],
req_id_to_index={request1.request_id: 0, request2.request_id: 1},
- sampled_token_ids=[[100], [121]],
+ sampled_token_ids=[np.array([100]), np.array([121])],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
diff --git a/tests/v1/determinism/conftest.py b/tests/v1/determinism/conftest.py
new file mode 100644
index 0000000000000..3c2136e005849
--- /dev/null
+++ b/tests/v1/determinism/conftest.py
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
+ """Automatically enable batch invariant kernel overrides for all tests."""
+ monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
+ yield
diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
similarity index 92%
rename from tests/v1/generation/test_batch_invariance.py
rename to tests/v1/determinism/test_batch_invariance.py
index 8fd038bca5d0f..f018ee551dbfe 100644
--- a/tests/v1/generation/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -6,66 +6,9 @@ import random
import pytest
import torch
+from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
from vllm import LLM, SamplingParams
-from vllm.platforms import current_platform
-
-skip_unsupported = pytest.mark.skipif(
- not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
- reason="Requires CUDA and >= Hopper (SM90)",
-)
-
-
-@pytest.fixture(autouse=True)
-def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
- """Automatically enable batch invariant kernel overrides for all tests."""
- monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
- yield
-
-
-def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
- # Generate more realistic prompts that will actually produce varied tokens
- # Use a mix of common English text patterns
-
- prompt_templates = [
- # Question-answer style
- "Question: What is the capital of France?\nAnswer: The capital of France is",
- "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
- "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
- # Story/narrative style
- "Once upon a time in a distant galaxy, there lived",
- "The old man walked slowly down the street, remembering",
- "In the year 2157, humanity finally discovered",
- # Technical/code style
- "To implement a binary search tree in Python, first we need to",
- "The algorithm works by iterating through the array and",
- "Here's how to optimize database queries using indexing:",
- # Factual/informative style
- "The Renaissance was a period in European history that",
- "Climate change is caused by several factors including",
- "The human brain contains approximately 86 billion neurons which",
- # Conversational style
- "I've been thinking about getting a new laptop because",
- "Yesterday I went to the store and bought",
- "My favorite thing about summer is definitely",
- ]
-
- # Pick a random template
- base_prompt = random.choice(prompt_templates)
-
- if max_words < min_words:
- max_words = min_words
- target_words = random.randint(min_words, max_words)
-
- if target_words > 50:
- # For longer prompts, repeat context
- padding_text = (
- " This is an interesting topic that deserves more explanation. "
- * (target_words // 50)
- )
- base_prompt = base_prompt + padding_text
-
- return base_prompt
@skip_unsupported
@@ -204,22 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
llm_bsN.shutdown()
-def _extract_step_logprobs(request_output):
- if getattr(request_output, "outputs", None):
- inner = request_output.outputs[0]
- if hasattr(inner, "logprobs") and inner.logprobs is not None:
- t = torch.tensor(
- [
- inner.logprobs[i][tid].logprob
- for i, tid in enumerate(inner.token_ids)
- ],
- dtype=torch.float32,
- )
- return t, inner.token_ids
-
- return None, None
-
-
@skip_unsupported
@pytest.mark.parametrize(
"backend",
diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py
new file mode 100644
index 0000000000000..23f47863dd23f
--- /dev/null
+++ b/tests/v1/determinism/test_online_batch_invariance.py
@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+HTTP-based batch invariance test: send requests to a running
+vLLM server and compare BS=1 vs BS=N results (tokens and per-step logprobs).
+
+Environment variables:
+ - VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B / DeepSeek-R1)
+ - VLLM_TP_SIZE: tensor parallelism size (e.g., 4)
+
+"""
+
+import os
+import random
+import sys
+from typing import Any
+
+import openai
+from utils import _random_prompt, skip_unsupported
+
+from tests.utils import RemoteOpenAIServer
+
+
+def _request_completion(
+ client: openai.OpenAI,
+ model: str,
+ prompt: Any,
+ sp: dict[str, Any],
+ max_retries: int = 3,
+ retry_backoff: float = 0.5,
+) -> dict[str, Any] | None:
+ payload: dict[str, Any] = {"model": model, "prompt": prompt}
+ payload.update(sp)
+
+ for attempt in range(max_retries + 1):
+ try:
+ completion = client.completions.create(**payload)
+ # Convert to plain dict so downstream logic can keep using
+ # dict-style access just like with raw HTTP JSON.
+ return completion.model_dump()
+ except Exception as e: # pragma: no cover
+ if attempt < max_retries:
+ import time as _t
+
+ _t.sleep(retry_backoff * (2**attempt))
+ continue
+ sys.stderr.write(f"Error: {e}\n")
+ return None
+ return None
+
+
+def _extract_tokens_and_logprobs(
+ choice: dict[str, Any],
+) -> tuple[list[Any], list[float] | None]:
+ tokens: list[Any] = []
+ token_logprobs: list[float] | None = None
+ lp = choice.get("logprobs")
+ if lp and isinstance(lp, dict):
+ tokens = lp.get("token_ids") or lp.get("tokens") or []
+ token_logprobs = lp.get("token_logprobs", None)
+ return tokens, token_logprobs
+
+
+def _compare_bs1_vs_bsn_single_process(
+ prompts: list[str],
+ sp_kwargs: dict[str, Any],
+ client: openai.OpenAI,
+ model_name: str,
+) -> None:
+ # BS=1
+ bs1_tokens_per_prompt: list[list[Any]] = []
+ bs1_logprobs_per_prompt: list[list[float] | None] = []
+ for p in prompts:
+ resp = _request_completion(client, model_name, p, sp_kwargs)
+ if resp is None or not resp.get("choices"):
+ raise AssertionError("BS=1 empty/failed response")
+ choice = resp["choices"][0]
+ toks, lps = _extract_tokens_and_logprobs(choice)
+ if lps is None:
+ raise AssertionError(
+ "logprobs not returned; ensure server supports 'logprobs'"
+ )
+ bs1_tokens_per_prompt.append(list(toks))
+ bs1_logprobs_per_prompt.append(list(lps))
+
+ # BS=N
+ bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts) # type: ignore[list-item]
+ bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
+ resp = _request_completion(client, model_name, prompts, sp_kwargs)
+ if resp is None or not resp.get("choices"):
+ raise AssertionError("BS=N empty/failed batched response")
+ choices = resp.get("choices", [])
+ if len(choices) != len(prompts):
+ raise AssertionError(
+ f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
+ )
+ for idx, choice in enumerate(choices):
+ toks, lps = _extract_tokens_and_logprobs(choice)
+ if lps is None:
+ raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
+ bsN_tokens_per_prompt[idx] = list(toks)
+ bsN_logprobs_per_prompt[idx] = list(lps)
+
+ # compare
+ for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
+ zip(
+ bs1_tokens_per_prompt,
+ bsN_tokens_per_prompt,
+ bs1_logprobs_per_prompt,
+ bsN_logprobs_per_prompt,
+ )
+ ):
+ if tokens_bs1 != tokens_bsN:
+ raise AssertionError(
+ f"Prompt {i} (sampling): Different tokens sampled. "
+ f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
+ )
+ if logprobs_bs1 is None or logprobs_bsN is None:
+ raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs")
+ if len(logprobs_bs1) != len(logprobs_bsN):
+ raise AssertionError(
+ f"Prompt {i}: Different number of steps: "
+ f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)."
+ )
+ for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+ if a != b:
+ diff = abs(a - b)
+ raise AssertionError(
+ f"Prompt {i} Step {t}: Bitwise mismatch "
+ f"(abs diff={diff:.6e}). "
+ f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
+ )
+
+
+@skip_unsupported
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
+ random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
+ model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+ prompts_all = [_random_prompt(10, 50) for _ in range(32)]
+
+ sp_kwargs: dict[str, Any] = {
+ "temperature": 0.6,
+ "top_p": 1.0,
+ "max_tokens": 8,
+ "seed": 42,
+ "logprobs": 5,
+ }
+
+ tp_size = os.getenv("VLLM_TP_SIZE", "1")
+ server_args: list[str] = []
+ if tp_size:
+ server_args += ["-tp", tp_size]
+
+ with RemoteOpenAIServer(model_name, server_args) as server:
+ client = server.get_client()
+ _compare_bs1_vs_bsn_single_process(
+ prompts=prompts_all,
+ sp_kwargs=sp_kwargs,
+ client=client,
+ model_name=model_name,
+ )
diff --git a/tests/v1/generation/test_rms_norm_batch_invariant.py b/tests/v1/determinism/test_rms_norm_batch_invariant.py
similarity index 97%
rename from tests/v1/generation/test_rms_norm_batch_invariant.py
rename to tests/v1/determinism/test_rms_norm_batch_invariant.py
index f79eba58d6ef2..390872519528c 100644
--- a/tests/v1/generation/test_rms_norm_batch_invariant.py
+++ b/tests/v1/determinism/test_rms_norm_batch_invariant.py
@@ -9,15 +9,10 @@ with the standard CUDA-based implementation to ensure numerical accuracy.
import pytest
import torch
+from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.platforms import current_platform
-
-skip_unsupported = pytest.mark.skipif(
- not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
- reason="Requires CUDA and >= Hopper (SM90)",
-)
@skip_unsupported
diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py
new file mode 100644
index 0000000000000..5141837faea04
--- /dev/null
+++ b/tests/v1/determinism/utils.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import random
+
+import pytest
+import torch
+
+from vllm.platforms import current_platform
+
+skip_unsupported = pytest.mark.skipif(
+ not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
+ reason="Requires CUDA and >= Hopper (SM90)",
+)
+
+
+def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
+ # Generate more realistic prompts that will actually produce varied tokens
+ # Use a mix of common English text patterns
+
+ prompt_templates = [
+ # Question-answer style
+ "Question: What is the capital of France?\nAnswer: The capital of France is",
+ "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
+ "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
+ # Story/narrative style
+ "Once upon a time in a distant galaxy, there lived",
+ "The old man walked slowly down the street, remembering",
+ "In the year 2157, humanity finally discovered",
+ # Technical/code style
+ "To implement a binary search tree in Python, first we need to",
+ "The algorithm works by iterating through the array and",
+ "Here's how to optimize database queries using indexing:",
+ # Factual/informative style
+ "The Renaissance was a period in European history that",
+ "Climate change is caused by several factors including",
+ "The human brain contains approximately 86 billion neurons which",
+ # Conversational style
+ "I've been thinking about getting a new laptop because",
+ "Yesterday I went to the store and bought",
+ "My favorite thing about summer is definitely",
+ ]
+
+ # Pick a random template
+ base_prompt = random.choice(prompt_templates)
+
+ if max_words < min_words:
+ max_words = min_words
+ target_words = random.randint(min_words, max_words)
+
+ if target_words > 50:
+ # For longer prompts, repeat context
+ padding_text = (
+ " This is an interesting topic that deserves more explanation. "
+ * (target_words // 50)
+ )
+ base_prompt = base_prompt + padding_text
+
+ return base_prompt
+
+
+def _extract_step_logprobs(request_output):
+ if getattr(request_output, "outputs", None):
+ inner = request_output.outputs[0]
+ if hasattr(inner, "logprobs") and inner.logprobs is not None:
+ t = torch.tensor(
+ [
+ inner.logprobs[i][tid].logprob
+ for i, tid in enumerate(inner.token_ids)
+ ],
+ dtype=torch.float32,
+ )
+ return t, inner.token_ids
+
+ return None, None
diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py
index dbe403ece0514..00d93e1ba0b53 100644
--- a/tests/v1/e2e/test_async_scheduling.py
+++ b/tests/v1/e2e/test_async_scheduling.py
@@ -15,7 +15,7 @@ from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B"
-MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base"
+MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
first_prompt = (
@@ -29,7 +29,8 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [
default_params = dict(
temperature=0.0, # greedy
- max_tokens=20,
+ max_tokens=23,
+ min_tokens=18,
)
@@ -65,20 +66,13 @@ def test_without_spec_decoding(
(True, "mp", True, None, False),
(True, "uni", True, None, False),
(False, "mp", True, None, True),
- # Async scheduling + preemption + chunked prefill needs to be fixed (WIP)
- # (True, "mp", True, None, True),
- # (True, "uni", True, None, True),
+ (True, "mp", True, None, True),
+ (True, "uni", True, None, True),
]
- run_tests(
- monkeypatch,
- MODEL,
- test_configs,
- test_sampling_params,
- )
+ run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
-@pytest.mark.skip("MTP model too big to run in fp32 in CI")
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
@@ -86,9 +80,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""
spec_config = {
- "method": "mtp",
+ "method": "eagle3",
"num_speculative_tokens": 2,
+ "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
}
+    # Set a small draft max_model_len to force the case where the tokens don't fit in the drafter.
spec_config_short = spec_config | {"max_model_len": 50}
# test_preemption, executor, async_scheduling,
@@ -103,17 +99,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(False, "mp", True, spec_config_short, True),
(True, "uni", True, spec_config, False),
(True, "uni", True, spec_config_short, False),
- # Async scheduling + preemption + chunked prefill needs to be fixed (WIP)
- # (True, "mp", True, spec_config, True),
- # (True, "uni", True, spec_config_short, True),
+ (True, "mp", True, spec_config, True),
+ (True, "uni", True, spec_config_short, True),
]
- run_tests(
- monkeypatch,
- MTP_MODEL,
- test_configs,
- [{}],
- )
+ run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
@dynamo_config.patch(cache_size_limit=16)
@@ -184,16 +174,15 @@ def run_tests(
and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
- # because the acceptance rate can vary, we use a looser
- # tolerance here.
assert (
- pytest.approx(test_acceptance_rate, rel=5e-2)
- == base_acceptance_rate
+ test_acceptance_rate > base_acceptance_rate
+ or test_acceptance_rate
+ == pytest.approx(base_acceptance_rate, rel=5e-2)
)
else:
# Currently the reported acceptance rate is expected to be
- # lower when we skip drafting altogether.
- assert test_acceptance_rate > 0.05
+ # lower when we sometimes skip drafting altogether.
+ assert test_acceptance_rate > 0.1
print(
f"PASSED: config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}"
@@ -222,6 +211,7 @@ def run_test(
):
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
+ # Force preemptions
dict(num_gpu_blocks_override=32)
if test_preemption
else dict(gpu_memory_utilization=0.9)
@@ -240,6 +230,7 @@ def run_test(
model,
max_model_len=512,
enable_chunked_prefill=test_prefill_chunking,
+ # Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,
# enforce_eager=True,
async_scheduling=async_scheduling,
@@ -257,10 +248,7 @@ def run_test(
results.append(
vllm_model.generate(
example_prompts,
- sampling_params=SamplingParams(
- **default_params,
- **override_params,
- ),
+ sampling_params=SamplingParams(**default_params, **override_params),
return_logprobs=True,
)
)
@@ -272,9 +260,7 @@ def run_test(
if test_preemption:
preemptions = _get_count(
- metrics_before,
- metrics_after,
- "vllm:num_preemptions",
+ metrics_before, metrics_after, "vllm:num_preemptions"
)
assert preemptions > 0, "preemption test had no preemptions"
diff --git a/tests/v1/entrypoints/openai/serving_responses/test_image.py b/tests/v1/entrypoints/openai/serving_responses/test_image.py
index 980d83b787e7a..be5693bbf2736 100644
--- a/tests/v1/entrypoints/openai/serving_responses/test_image.py
+++ b/tests/v1/entrypoints/openai/serving_responses/test_image.py
@@ -15,10 +15,10 @@ MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
MAXIMUM_IMAGES = 2
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_ASSETS = [
- "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
- "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
- "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
- "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+ "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+ "Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
+ "1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
+ "RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
]
diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index a9817313cf022..453ccc81eb14a 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -49,11 +49,13 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
+PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
+DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
-SMI_BIN=$(which nvidia-smi || which rocm-smi)
+SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
@@ -89,8 +91,13 @@ get_model_args() {
get_num_gpus() {
if [[ "$SMI_BIN" == *"nvidia"* ]]; then
echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)"
- else
+ elif [[ "$SMI_BIN" == *"rocm"* ]]; then
echo "$($SMI_BIN -l | grep GPU | wc -l)"
+ else
+ # works for non-cuda platforms,
+ # assuming at least 1 device and
+ # letting the system decide which card to use
+ echo "1"
fi
}
@@ -136,6 +143,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
+ --block-size ${PREFILL_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
@@ -177,6 +185,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
+ --block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 8e421717fea30..b264e5108c16d 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -11,6 +11,7 @@ import uuid
from collections import defaultdict
from unittest.mock import patch
+import numpy as np
import pytest
import ray
import torch
@@ -407,6 +408,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
+ block_size=self.block_size,
),
remote_tp_size=remote_tp_size,
)
@@ -652,6 +654,7 @@ class TestNixlHandshake:
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
+ block_size=worker.block_size,
)
with pytest.raises(RuntimeError):
@@ -706,6 +709,7 @@ class TestNixlHandshake:
block_lens=[i * 2 for i in worker.block_len_per_layer],
attn_backend_name=worker.backend_name,
kv_cache_layout="HND",
+ block_size=worker.block_size,
)
# We don't check layout for homogeneous TP and MLA for now, as the
@@ -823,7 +827,7 @@ def test_kv_connector_stats_aggregation():
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
- sampled_token_ids=[[123]], # dummy token
+ sampled_token_ids=[np.array([123])], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
@@ -904,7 +908,7 @@ def test_multi_kv_connector_stats_aggregation():
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
- sampled_token_ids=[[123]],
+ sampled_token_ids=[np.array([123])],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
@@ -962,7 +966,7 @@ def test_scheduler_kv_connector_stats_aggregation():
model_output = ModelRunnerOutput(
req_ids=["req_0"],
req_id_to_index={"req_0": 0},
- sampled_token_ids=[[123]],
+ sampled_token_ids=[np.array([123])],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 421da52415559..805b8c86b0804 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -324,6 +324,7 @@ def test_prepare_inputs_padded():
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
+@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
@@ -335,6 +336,7 @@ def test_load_model(
attn_backend,
pp_size,
use_distinct_embed_tokens,
+ use_distinct_lm_head,
monkeypatch,
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -350,12 +352,13 @@ def test_load_model(
# Setup draft model mock
mock_model = mock.MagicMock()
+ mock_model.model = mock.MagicMock()
+ mock_model.has_own_embed_tokens = use_distinct_embed_tokens
if use_distinct_embed_tokens:
- # Some models can have a different hidden size than the target model,
- # so we test that their embed_tokens doesn't get overwritten
- mock_model.model.embed_tokens.weight.shape = (131072, 2048)
- else:
- mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+ mock_model.model.embed_tokens = mock.MagicMock()
+ mock_model.has_own_lm_head = use_distinct_lm_head
+ if use_distinct_lm_head:
+ mock_model.lm_head = mock.MagicMock()
mock_get_model.return_value = mock_model
@@ -391,15 +394,13 @@ def test_load_model(
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
- target_model.model.embed_tokens.weight.shape = (131072, 4096)
+ target_model.lm_head = mock.MagicMock()
+ target_model.model.embed_tokens = mock.MagicMock()
from vllm.model_executor.models import SupportsMultiModal
assert not isinstance(target_model, SupportsMultiModal)
- if method == "eagle":
- target_model.lm_head = mock.MagicMock()
-
# Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8)
@@ -409,18 +410,18 @@ def test_load_model(
# Verify common interactions
mock_get_model.assert_called_once()
- # Verify that EAGLE models gain the lm head from the target model
- if method == "eagle":
- assert proposer.model.lm_head == target_model.lm_head
+ # Verify that the lm head is set correctly
+ if use_distinct_lm_head:
+ assert proposer.model.lm_head is not target_model.lm_head
+ else:
+ assert proposer.model.lm_head is target_model.lm_head
# Verify that the embed tokens are set correctly
# If pp_size is > 1, the embed tokens should be distinct
if pp_size > 1 or use_distinct_embed_tokens:
- assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
+ assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
else:
- # When pp_size is 1 and the draft and target models have
- # embed_tokens of the same shape, they should be shared.
- assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+ assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py
index 6d59b58e739eb..c5c0491abaf7c 100644
--- a/tests/v1/spec_decode/test_mtp.py
+++ b/tests/v1/spec_decode/test_mtp.py
@@ -67,6 +67,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro
mock_model = mock.MagicMock()
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model
+ # MTP does not have its own embed_tokens or lm_head
+ # so it should share them with the target model
+ mock_model.has_own_embed_tokens = False
+ mock_model.has_own_lm_head = False
target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
diff --git a/tools/install_nixl_from_source_ubuntu.py b/tools/install_nixl_from_source_ubuntu.py
index 4a20b6b7bb8fb..b8a55c615426e 100644
--- a/tools/install_nixl_from_source_ubuntu.py
+++ b/tools/install_nixl_from_source_ubuntu.py
@@ -95,6 +95,7 @@ def install_system_dependencies():
"meson",
"libtool",
"libtool-bin",
+ "pkg-config",
]
run_command(["apt-get", "update"])
run_command(["apt-get", "install", "-y"] + apt_packages)
@@ -175,6 +176,7 @@ def build_and_install_prerequisites(args):
build_env["LD_LIBRARY_PATH"] = (
f"{ucx_lib_path}:{ucx_plugin_path}:{existing_ld_path}".strip(":")
)
+ build_env["LDFLAGS"] = "-Wl,-rpath,$ORIGIN"
print(f"--> Using LD_LIBRARY_PATH: {build_env['LD_LIBRARY_PATH']}", flush=True)
temp_wheel_dir = os.path.join(ROOT_DIR, "temp_wheelhouse")
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 096266c9764e8..66cf6472eee40 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2702,6 +2702,31 @@ def cpu_attention_with_kv_cache(
)
+def cpu_gemm_wna16(
+ input: torch.Tensor,
+ q_weight: torch.Tensor,
+ scales: torch.Tensor,
+ zeros: torch.Tensor | None,
+ g_idx: torch.Tensor | None,
+ bias: torch.Tensor | None,
+ pack_factor: int,
+ isa_hint: str,
+) -> torch.Tensor:
+ output = torch.empty((input.size(0), scales.size(1)), dtype=input.dtype)
+ torch.ops._C.cpu_gemm_wna16(
+ input,
+ q_weight,
+ output,
+ scales,
+ zeros,
+ g_idx,
+ bias,
+ pack_factor,
+ isa_hint,
+ )
+ return output
+
+
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 6962810bdd09f..9e540fd437bfb 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -310,7 +310,8 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name,
**extra_impl_args,
)
- self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
+ backend_name = self.attn_backend.get_name()
+ self.backend = AttentionBackendEnum.__members__.get(backend_name)
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index 0e9b0fbe2c028..dddb050ec180e 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -49,6 +49,7 @@ from vllm.benchmarks.lib.ready_checker import wait_for_endpoint
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.gc_utils import freeze_gc_heap
+from vllm.utils.network_utils import join_host_port
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -1333,8 +1334,9 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
api_url = f"{args.base_url}{args.endpoint}"
base_url = f"{args.base_url}"
else:
- api_url = f"http://{args.host}:{args.port}{args.endpoint}"
- base_url = f"http://{args.host}:{args.port}"
+ host_port = join_host_port(args.host, args.port)
+ api_url = f"http://{host_port}{args.endpoint}"
+ base_url = f"http://{host_port}"
# Headers
headers = None
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index b0cdb08884a3b..11cf0f85c1787 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -299,7 +299,7 @@ class InductorAdaptor(CompilerInterface):
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
if disable_cache:
return
- # redirect the cache directory to a sub-directory
+ # redirect the cache directory to a subdirectory
# set flags so that Inductor and Triton store their cache
# in the cache_dir, then users only need to copy the cache_dir
# to another machine to reuse the cache.
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index e325bca73abb0..11a18c0e6bb78 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -159,7 +159,7 @@ def support_torch_compile(
`mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
dim to be decorated with `mark_unbacked`. This is useful if we would like to
- enforce that dynamo do not specialize on 0/1 values in the case of dummy input
+ enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
"""
@@ -483,7 +483,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
Context manager to set/unset customized cudagraph partition wrappers.
If we're using Inductor-based graph partitioning, we currently have the
- whole `fx.Graph` before Inductor lowering and and the piecewise
+ whole `fx.Graph` before Inductor lowering and the piecewise
splitting happens after all graph passes and fusions. Here, we add
a custom hook for Inductor to wrap each partition with our static
graph wrapper class to maintain more control over static graph
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index 0c2210d72ce07..0e8bb2fc97351 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -18,6 +18,7 @@ if current_platform.is_cuda_alike():
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
+ from .sequence_parallelism import SequenceParallelismPass
if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
@@ -25,7 +26,6 @@ if current_platform.is_cuda():
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
-from .sequence_parallelism import SequenceParallelismPass
logger = init_logger(__name__)
diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py
index 31624a8fdcc0f..bb4dcf12d865d 100644
--- a/vllm/compilation/sequence_parallelism.py
+++ b/vllm/compilation/sequence_parallelism.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
@@ -10,98 +12,28 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ kFp8StaticTensorSym,
+)
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
+from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
+from .noop_elimination import NoOpEliminationPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
-class _RMSNormAndQuantOpHelper:
- """Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
+def get_first_out_wrapper(fn):
+ @functools.wraps(fn)
+ def wrapper(*args):
+ return fn(*args)[0]
- def __init__(
- self,
- epsilon: float,
- dtype: torch.dtype,
- device: str,
- quant_op: torch._ops.OpOverload | None = None,
- **kwargs,
- ):
- self.epsilon = epsilon
- self.dtype = dtype
- self.device = device
- self.quant_op = quant_op
-
- def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
- return torch.ops.higher_order.auto_functionalized(
- torch.ops._C.rms_norm.default,
- result=result_buffer,
- input=input_tensor,
- weight=weight_tensor,
- epsilon=self.epsilon,
- )
-
- def _functional_fused_add_rmsnorm(
- self, input_tensor, residual_tensor, weight_tensor
- ):
- return torch.ops.higher_order.auto_functionalized(
- torch.ops._C.fused_add_rms_norm.default,
- input=input_tensor,
- residual=residual_tensor,
- weight=weight_tensor,
- epsilon=self.epsilon,
- )
-
- def _functional_rmsnorm_then_quant(
- self,
- rmsnorm_result_buffer,
- quant_result_buffer,
- input_tensor,
- weight_tensor,
- scale_tensor,
- ):
- if self.quant_op is None:
- raise RuntimeError(
- "_RMSNormAndQuantOpHelper was not initialized with a quant_op."
- )
- rmsnorm_out_tuple = self._functional_rmsnorm(
- rmsnorm_result_buffer, input_tensor, weight_tensor
- )
- quant_out_tuple = torch.ops.higher_order.auto_functionalized(
- self.quant_op,
- result=quant_result_buffer,
- input=rmsnorm_out_tuple[1],
- scale=scale_tensor,
- )
- return quant_out_tuple
-
- def _functional_fused_add_rmsnorm_then_quant(
- self,
- quant_result_buffer,
- input_tensor,
- residual_tensor,
- weight_tensor,
- scale_tensor,
- ):
- if self.quant_op is None:
- raise RuntimeError(
- "_RMSNormAndQuantOpHelper was not initialized with a quant_op."
- )
- fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
- input_tensor, residual_tensor, weight_tensor
- )
- quant_out_tuple = torch.ops.higher_order.auto_functionalized(
- self.quant_op,
- result=quant_result_buffer,
- input=fused_add_rmsnorm_out_tuple[1],
- scale=scale_tensor,
- )
- return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
+ return wrapper
-class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
+class _SequenceParallelPatternHelper:
"""Helper for sequence parallelism patterns."""
def __init__(
@@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
epsilon: float,
dtype: torch.dtype,
device: str,
- quant_op: torch._ops.OpOverload | None = None,
- **kwargs,
):
- super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
+ self.epsilon = epsilon
+ self.dtype = dtype
+ self.device = device
self.tp_group = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+ def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
+ super().__init__(epsilon, dtype, device)
+ self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
+
def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
- permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
- return [input, permute, arg3_1]
+ return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
- permute: torch.Tensor,
arg3_1: torch.Tensor,
):
all_reduce = self._all_reduce(input)
- rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
+ rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
- return rmsnorm[1], all_reduce
+ return rmsnorm, all_reduce
def replacement(
input: torch.Tensor,
- permute: torch.Tensor,
arg3_1: torch.Tensor,
):
reduce_scatter = self._reduce_scatter(input)
- rmsnorm_result = torch.empty_like(reduce_scatter)
- rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
-
- all_gather = self._all_gather(rmsnorm[1])
-
+ rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
+ all_gather = self._all_gather(rmsnorm)
return all_gather, reduce_scatter
pm.register_replacement(
@@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
+ def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
+ super().__init__(epsilon, dtype, device)
+ self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
+
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
- rmsnorm = self._functional_fused_add_rmsnorm(
- all_reduce, residual, rms_norm_weights
- )
- return rmsnorm[1], rmsnorm[2]
+ rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
+ return rmsnorm[0], rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
+ # pattern matcher replaces from top-to-bottom,
+ # so residual is still the full size here.
+ # the slice below becomes a noop once the previous rmsnorm's seqpar pattern is replaced
reduce_scatter = self._reduce_scatter(mm_1)
- rmsnorm = self._functional_fused_add_rmsnorm(
- reduce_scatter, residual, rms_norm_weights
- )
- all_gather = self._all_gather(rmsnorm[1])
- return all_gather, rmsnorm[2]
+ residual = residual[0 : reduce_scatter.size(0), ...]
+ rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
+ all_gather = self._all_gather(rmsnorm[0])
+ # shape of residual changes but that's fine,
+ # next node is already slicing it, now becomes a noop
+ return all_gather, rmsnorm[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
-
-
-class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
- def get_inputs(self):
- mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
-
- residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
- rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
-
- return [
- residual,
- mm_1,
- rms_norm_weights,
- ]
-
- def register(self, pm_pass: PatternMatcherPass):
- def pattern(
- residual: torch.Tensor,
- mm_1: torch.Tensor,
- rms_norm_weights: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- all_reduce = self._all_reduce(mm_1)
- rmsnorm = self._functional_fused_add_rmsnorm(
- all_reduce, residual, rms_norm_weights
- )
- return rmsnorm[1]
-
- def replacement(
- residual: torch.Tensor,
- mm_1: torch.Tensor,
- rms_norm_weights: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- reduce_scatter = self._reduce_scatter(mm_1)
- rmsnorm = self._functional_fused_add_rmsnorm(
- reduce_scatter, residual, rms_norm_weights
- )
- normalized = self._all_gather(rmsnorm[1])
- return normalized
-
pm.register_replacement(
- pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+ get_first_out_wrapper(pattern),
+ get_first_out_wrapper(replacement),
+ self.get_inputs(),
+ pm.fwd_only,
+ pm_pass,
)
@@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
- self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
+ self,
+ epsilon: float,
+ dtype: torch.dtype,
+ device: str,
):
- super().__init__(epsilon, dtype, device, quant_op=op)
+ super().__init__(epsilon, dtype, device)
+ self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
+ self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
- rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
- quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
- return [input, rmsnorm_result, quant_result, weight, scale]
+ return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
- rmsnorm_result: torch.Tensor,
- quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
all_reduce = self._all_reduce(input)
- static_fp8 = self._functional_rmsnorm_then_quant(
- rmsnorm_result, quant_result, all_reduce, weight, scale
- )
- return static_fp8[1], all_reduce
+ rms = self.rmsnorm_matcher(all_reduce, weight)
+ quant, _ = self.quant_matcher(rms, scale)
+ return quant, all_reduce
def replacement(
input: torch.Tensor,
- rmsnorm_result: torch.Tensor,
- quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
reduce_scatter = self._reduce_scatter(input)
-
- rmsnorm_result = torch.empty_like(
- reduce_scatter, dtype=rmsnorm_result.dtype
- )
- quant_result = torch.empty_like(
- rmsnorm_result, # Output of RMSNorm
- dtype=quant_result.dtype,
- )
- static_fp8 = self._functional_rmsnorm_then_quant(
- rmsnorm_result, quant_result, reduce_scatter, weight, scale
- )
- all_gather = self._all_gather(static_fp8[1])
+ rms = self.rmsnorm_matcher(reduce_scatter, weight)
+ quant, _ = self.quant_matcher(rms, scale)
+ all_gather = self._all_gather(quant)
return all_gather, reduce_scatter
@@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
- def __init__(
- self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
- ):
- super().__init__(epsilon, dtype, device, quant_op=op)
+ def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
+ super().__init__(epsilon, dtype, device)
+ self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
+ self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
-
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
- result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
- return [
- result,
- residual,
- mm_1,
- rms_norm_weights,
- scale,
- ]
+ return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
- result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(mm_1)
- static_fp8, rmsnorm_residual_out = (
- self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
- result, all_reduce, residual, rms_norm_weights, scale
- )
+ rms, residual_out = self.rmsnorm_matcher(
+ all_reduce, rms_norm_weights, residual
)
- return static_fp8[1], rmsnorm_residual_out
+ quant, _ = self.quant_matcher(rms, scale)
+ return quant, residual_out
def replacement(
- result: torch.Tensor,
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
+ # pattern matcher replaces from top-to-bottom,
+ # so residual is still the full size here.
+ # add a temporary slice which will become a noop
+ # once the seqpar pattern with the previous rmsnorm is replaced
reduce_scatter = self._reduce_scatter(mm_1)
- quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
- static_fp8, rmsnorm_residual_out = (
- self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
- quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
- )
+ residual = residual[0 : reduce_scatter.size(0), ...]
+ rms, residual_out = self.rmsnorm_matcher(
+ reduce_scatter, rms_norm_weights, residual
)
- all_gather = self._all_gather(static_fp8[1])
- return all_gather, rmsnorm_residual_out
+ quant, _ = self.quant_matcher(rms, scale)
+ all_gather = self._all_gather(quant)
+ # shape of residual changes but that's fine,
+ # next node is already slicing it, now becomes a noop
+ return all_gather, residual_out
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
-
-class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
- def __init__(
- self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
- ):
- super().__init__(epsilon, dtype, device, quant_op=op)
-
- def get_inputs(self):
- mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
-
- residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
- rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
- result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
- scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
-
- return [
- result,
- residual,
- mm_1,
- rms_norm_weights,
- scale,
- ]
-
- def register(self, pm_pass: PatternMatcherPass):
- def pattern(
- result: torch.Tensor,
- residual: torch.Tensor,
- mm_1: torch.Tensor,
- rms_norm_weights: torch.Tensor,
- scale: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- all_reduce = self._all_reduce(mm_1)
- static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
- result, all_reduce, residual, rms_norm_weights, scale
- )
- return static_fp8[1]
-
- def replacement(
- result: torch.Tensor,
- residual: torch.Tensor,
- mm_1: torch.Tensor,
- rms_norm_weights: torch.Tensor,
- scale: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- reduce_scatter = self._reduce_scatter(mm_1)
- quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
- static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
- quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
- )
- normalized = self._all_gather(static_fp8[1])
- return normalized
-
pm.register_replacement(
- pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+ get_first_out_wrapper(pattern),
+ get_first_out_wrapper(replacement),
+ self.get_inputs(),
+ pm.fwd_only,
+ pm_pass,
)
@@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
significantly reduce communication overhead and improve overall model
performance.
+
+
+ This pass splits up the residual tensor across TP ranks and hence divides its size.
+ Because the pattern matcher starts at the end of the graph, the replacement
+ contains a slice that temporarily conforms the input residual to the correct size.
+ After all patterns have been matched, we use a NoOpEliminationPass to clean up
+ what have now become no-op slices.
+
+ Note that an older version of the pass did not need this as it operated only on
+ the rms_norm and fused_rms_norm_add custom ops, which did not complain about
+ mismatched shapes during replacement. So this approach has the same assumption that
+ correctness is only maintained if all rms_norm operations are split across ranks.
+
+ Correctness-wise, this approach is strictly better than before - before,
+ the graph was incorrect semantically and shape-wise during the pass.
+ With this approach there's only semantic incorrectness during the pass.
+ Both approaches restore a correct graph once all patterns are matched.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
+ # Used to cleanup redundant views created temporarily
+ # to circumvent residual shape change issues
+ self.noop_cleanup = NoOpEliminationPass(config)
+ self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
+
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass"
)
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
- fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
- epsilon, self.model_dtype, self.device, fp8_quant_op
+ epsilon, self.model_dtype, self.device
).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
- epsilon, self.model_dtype, self.device, fp8_quant_op
- ).register(self.patterns)
- LastAllReduceRMSNormStaticFP8Pattern(
- epsilon, self.model_dtype, self.device, fp8_quant_op
+ epsilon, self.model_dtype, self.device
).register(self.patterns)
# Normal RMSNorm patterns
@@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
epsilon, self.model_dtype, self.device
).register(self.patterns)
- LastAllReduceRMSNormPattern(
- epsilon, self.model_dtype, self.device
- ).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable(self, shape: int | None) -> bool:
@@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
+ # Clean up reshape nodes
+ self.noop_cleanup(graph)
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 10673041aa685..088d0b1af757a 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -18,6 +18,7 @@ from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING:
@@ -773,19 +774,8 @@ class CompilationConfig:
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
- # pre-compute the mapping from batch size to padded graph size
- self.bs_to_padded_graph_size = [
- 0 for i in range(self.max_cudagraph_capture_size + 1)
- ]
- for end, start in zip(
- self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
- [0] + self.cudagraph_capture_sizes,
- ):
- for bs in range(start, end):
- if bs == start:
- self.bs_to_padded_graph_size[bs] = start
- else:
- self.bs_to_padded_graph_size[bs] = end
+ # May get recomputed in the model runner if adjustment is needed for spec-decode
+ self.compute_bs_to_padded_graph_size()
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when mode is
@@ -922,3 +912,64 @@ class CompilationConfig:
enable_str,
op,
)
+
+ def adjust_cudagraph_sizes_for_spec_decode(
+ self, uniform_decode_query_len: int, tensor_parallel_size: int
+ ):
+ multiple_of = uniform_decode_query_len
+ if tensor_parallel_size > 1:
+ multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
+ if (
+ multiple_of % uniform_decode_query_len != 0
+ or multiple_of % tensor_parallel_size != 0
+ ):
+ raise ValueError(
+ f"Can't determine cudagraph shapes that are both a "
+ f"multiple of {uniform_decode_query_len} "
+ f"(num_speculative_tokens + 1) required by spec-decode "
+ f"and {tensor_parallel_size} (tensor_parallel_size) "
+                f"required by sequence parallelism; please adjust "
+ f"num_speculative_tokens or disable sequence parallelism"
+ )
+
+ if not self.cudagraph_capture_sizes or multiple_of <= 1:
+ return
+
+ assert self.max_cudagraph_capture_size is not None
+ rounded_sizes = sorted(
+ set(
+ round_up(size, multiple_of)
+ for size in self.cudagraph_capture_sizes
+ if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
+ )
+ )
+
+ if len(rounded_sizes) == 0:
+ logger.warning(
+ "No valid cudagraph sizes after rounding to multiple of "
+            "num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
+ " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
+ multiple_of,
+ )
+ return
+
+ self.max_cudagraph_capture_size = rounded_sizes[-1]
+ self.cudagraph_capture_sizes = rounded_sizes
+
+ # Recompute after adjusting the cudagraph sizes
+ self.compute_bs_to_padded_graph_size()
+
+ def compute_bs_to_padded_graph_size(self):
+ # pre-compute the mapping from batch size to padded graph size
+ self.bs_to_padded_graph_size = [
+ 0 for i in range(self.max_cudagraph_capture_size + 1)
+ ]
+ for end, start in zip(
+ self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+ [0] + self.cudagraph_capture_sizes,
+ ):
+ for bs in range(start, end):
+ if bs == start:
+ self.bs_to_padded_graph_size[bs] = start
+ else:
+ self.bs_to_padded_graph_size[bs] = end
diff --git a/vllm/config/model.py b/vllm/config/model.py
index b3a28af6de389..3e8790a26e0e3 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -33,10 +33,14 @@ from vllm.transformers_utils.config import (
try_get_generation_config,
try_get_safetensors_metadata,
try_get_tokenizer_config,
+ uses_custom_attention_masks,
uses_mrope,
)
+from vllm.transformers_utils.gguf_utils import (
+ maybe_patch_hf_config_from_gguf,
+)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
-from vllm.transformers_utils.utils import maybe_model_redirect
+from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
@@ -450,6 +454,12 @@ class ModelConfig:
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
+ if check_gguf_file(self.model):
+ raise ValueError(
+ "Using a tokenizer is mandatory when loading a GGUF model. "
+ "Please specify the tokenizer path or name using the "
+ "--tokenizer argument."
+ )
self.tokenizer = self.model
if self.tokenizer_revision is None:
self.tokenizer_revision = self.revision
@@ -508,6 +518,10 @@ class ModelConfig:
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn,
)
+ hf_config = maybe_patch_hf_config_from_gguf(
+ self.model,
+ hf_config,
+ )
self.hf_config = hf_config
if dict_overrides:
@@ -1006,6 +1020,8 @@ class ModelConfig:
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
+ "cpu_gptq",
+ "cpu_awq",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
@@ -1605,6 +1621,10 @@ class ModelConfig:
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
+ @property
+ def uses_custom_attention_masks(self) -> bool:
+ return uses_custom_attention_masks(self.hf_config)
+
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 61bcd15e06a84..9a6326d62e82e 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -210,6 +210,18 @@ class ParallelConfig:
class is dynamically inherited by the worker class. This is used to inject
new attributes and methods to the worker class for use in collective_rpc
calls."""
+ master_addr: str = "127.0.0.1"
+ """distributed master address for multi-node distributed
+ inference when distributed_executor_backend is mp."""
+ master_port: int = 29501
+ """distributed master port for multi-node distributed
+ inference when distributed_executor_backend is mp."""
+ node_rank: int = 0
+ """distributed node rank for multi-node distributed
+ inference when distributed_executor_backend is mp."""
+ nnodes: int = 1
+ """num of nodes for multi-node distributed
+ inference when distributed_executor_backend is mp."""
world_size: int = Field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
@@ -387,6 +399,23 @@ class ParallelConfig:
and self.data_parallel_size > 1
)
+ @property
+ def node_rank_within_dp(self) -> int:
+ return self.node_rank % self.nnodes_within_dp
+
+ @property
+ def nnodes_within_dp(self) -> int:
+ if self.nnodes == 1:
+ return 1
+ data_parallel_node_size = (
+ self.data_parallel_size // self.data_parallel_size_local
+ )
+ return self.nnodes // data_parallel_node_size
+
+ @property
+ def local_world_size(self) -> int:
+ return self.world_size // self.nnodes_within_dp
+
@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
@@ -528,6 +557,8 @@ class ParallelConfig:
ray_found = ray_utils.ray_is_available()
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
+ elif current_platform.is_cuda() and self.nnodes > 1:
+ backend = "mp"
elif (
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
@@ -565,6 +596,10 @@ class ParallelConfig:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
+ if self.distributed_executor_backend != "mp" and self.nnodes > 1:
+ raise ValueError(
+                "nnodes > 1 can only be set when distributed executor backend is mp."
+ )
@property
def use_ray(self) -> bool:
@@ -607,6 +642,11 @@ class ParallelConfig:
"Disabled the custom all-reduce kernel because it is not "
"supported on current platform."
)
+ if self.nnodes > 1:
+ self.disable_custom_all_reduce = True
+ logger.debug(
+ "Disabled the custom all-reduce since we are running on multi-node."
+ )
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError(
"Unable to use nsight profiling unless workers run with Ray."
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 444568994a95b..8194295ffedb6 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -6,7 +6,7 @@ from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
-from pydantic import Field, field_validator, model_validator
+from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self, deprecated
@@ -48,13 +48,6 @@ class SchedulerConfig:
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
- max_model_len: int = Field(default=8192, ge=1)
- """Maximum length of a sequence (including prompt and generated text).
-
- The default value here is mainly for convenience when testing.
- In real usage, this should duplicate `ModelConfig.max_model_len` via
- `EngineArgs`."""
-
max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently."""
@@ -89,6 +82,12 @@ class SchedulerConfig:
is_multimodal_model: bool = False
"""True if the model is multimodal."""
+ max_model_len: InitVar[int] = 8192
+ """Maximum length of a sequence (including prompt and generated text).
+
+ Note: This is stored in the ModelConfig, and is used only here to
+ provide fallbacks and validate other attributes."""
+
is_encoder_decoder: InitVar[bool] = False
"""True if the model is an encoder-decoder model.
@@ -199,7 +198,7 @@ class SchedulerConfig:
return value
return handler(value)
- def __post_init__(self, is_encoder_decoder: bool) -> None:
+ def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
@@ -221,7 +220,7 @@ class SchedulerConfig:
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
- self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+ self.long_prefill_token_threshold = int(max_model_len * 0.04)
logger.info(
"Concurrent partial prefills enabled with "
@@ -232,6 +231,8 @@ class SchedulerConfig:
self.long_prefill_token_threshold,
)
+ self.verify_max_model_len(max_model_len)
+
@property
@deprecated(
"`SchedulerConfig.chunked_prefill_enabled` has been renamed to "
@@ -245,15 +246,14 @@ class SchedulerConfig:
def chunked_prefill_enabled(self, value: bool):
self.enable_chunked_prefill = value
- @model_validator(mode="after")
- def _verify_args(self) -> Self:
+ def verify_max_model_len(self, max_model_len: int) -> Self:
if (
- self.max_num_batched_tokens < self.max_model_len
+ self.max_num_batched_tokens < max_model_len
and not self.enable_chunked_prefill
):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
- f"smaller than max_model_len ({self.max_model_len}). "
+ f"smaller than max_model_len ({max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
@@ -267,12 +267,12 @@ class SchedulerConfig:
f"({self.max_num_seqs})."
)
- if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
+ if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens,
- self.max_num_seqs * self.max_model_len,
+ self.max_num_seqs * max_model_len,
)
if self.max_num_partial_prefills > 1:
@@ -282,11 +282,11 @@ class SchedulerConfig:
"max_num_partial_prefills > 1."
)
- if self.long_prefill_token_threshold > self.max_model_len:
+ if self.long_prefill_token_threshold > max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
- f"than the max_model_len ({self.max_model_len})."
+ f"than the max_model_len ({max_model_len})."
)
if self.max_long_partial_prefills > self.max_num_partial_prefills:
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 31cdeabe501d2..13a8632413d91 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -3,7 +3,7 @@
import ast
import hashlib
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
@@ -29,31 +29,25 @@ else:
logger = init_logger(__name__)
-SpeculativeMethod = Literal[
- "ngram",
- "eagle",
- "eagle3",
- "medusa",
- "mlp_speculator",
- "draft_model",
- "deepseek_mtp",
- "ernie_mtp",
- "qwen3_next_mtp",
- "mimo_mtp",
- "longcat_flash_mtp",
- "pangu_ultra_moe_mtp",
- "mtp",
- "suffix",
-]
-MTP_MODEL_TYPES = (
+MTPModelTypes = Literal[
"deepseek_mtp",
"mimo_mtp",
"glm4_moe_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"longcat_flash_mtp",
+ "mtp",
"pangu_ultra_moe_mtp",
-)
+]
+EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
+SpeculativeMethod = Literal[
+ "ngram",
+ "medusa",
+ "mlp_speculator",
+ "draft_model",
+ "suffix",
+ EagleModelTypes,
+]
@config
@@ -244,7 +238,7 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
- if self.method in MTP_MODEL_TYPES:
+ if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method
)
@@ -361,7 +355,9 @@ class SpeculativeConfig:
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator"
- elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
+ elif self.draft_model_config.hf_config.model_type in get_args(
+ MTPModelTypes
+ ):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 1e6e455210c88..672b004c4aa56 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -14,13 +14,14 @@ from dataclasses import replace
from datetime import datetime
from functools import lru_cache
from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar, get_args
import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
+from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
@@ -374,10 +375,22 @@ class VllmConfig:
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
+ # Currently, async scheduling only support eagle speculative
+ # decoding.
if self.speculative_config is not None:
- raise ValueError(
- "Async scheduling is not yet compatible with speculative decoding."
- )
+ if self.speculative_config.method not in get_args(EagleModelTypes):
+ raise ValueError(
+ "Currently, async scheduling is only supported "
+ "with EAGLE/MTP kind of speculative decoding"
+ )
+ if self.speculative_config.disable_padded_drafter_batch:
+ raise ValueError(
+ "async scheduling for EAGLE/MTP kind of speculative "
+                    "decoding is enabled, but "
+                    "disable_padded_drafter_batch=True is not supported for "
+                    "this situation now. Please set "
+                    "disable_padded_drafter_batch=False"
+ )
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
@@ -445,8 +458,6 @@ class VllmConfig:
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True
- if self.compilation_config.pass_config.enable_sequence_parallelism:
- self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
@@ -483,21 +494,6 @@ class VllmConfig:
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
- elif (
- current_platform.is_cuda()
- and current_platform.is_device_capability(100)
- and self.model_config.max_model_len > 131072
- and not self.model_config.use_mla
- ):
- # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
- logger.warning_once(
- "NVIDIA Blackwell TRTLLM attention cannot support "
- "max_model_len >= 131072 (found "
- f"{self.model_config.max_model_len}), causing dynamic "
- "dispatching that breaks full cudagraphs. "
- "Overriding cudagraph_mode to PIECEWISE."
- )
- self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
@@ -635,6 +631,32 @@ class VllmConfig:
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()
+ if self.compilation_config.pass_config.enable_sequence_parallelism:
+ # With pipeline parallelism or dynamo partitioning,
+ # native rms norm tracing errors due to incorrect residual shape.
+ # Use custom rms norm to unblock. In the future,
+ # the pass will operate on higher-level IR to avoid the issue.
+ # TODO: https://github.com/vllm-project/vllm/issues/27894
+ is_fullgraph = (
+ self.compilation_config.use_inductor_graph_partition
+ or len(self.compilation_config.splitting_ops) == 0
+ )
+ if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
+ if "-rms_norm" not in self.compilation_config.custom_ops:
+ self.compilation_config.custom_ops.append("+rms_norm")
+ else:
+ regime = (
+ "Dynamo partition"
+ if not is_fullgraph
+ else "pipeline parallelism"
+ )
+ logger.warning_once(
+                    "Sequence parallelism not supported with "
+ "native rms_norm when using %s, "
+ "this will likely lead to an error.",
+ regime,
+ )
+
# final check of cudagraph mode after all possible updates
if current_platform.is_cuda_alike():
if (
@@ -929,7 +951,6 @@ class VllmConfig:
model_config = self.model_config
max_model_len = model_config.get_and_verify_max_len(max_model_len)
self.model_config.max_model_len = max_model_len
- self.scheduler_config.max_model_len = max_model_len
def try_verify_and_update_config(self):
if self.model_config is None:
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
index 5046cac2e90a7..052df19e34d72 100644
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from multiprocessing import shared_memory
from pickle import PickleBuffer
from threading import Event
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch
import torch
@@ -602,13 +602,87 @@ class MessageQueue:
return obj
return self.dequeue()
+ @staticmethod
+ def create_from_process_group_single_reader(
+ pg: ProcessGroup,
+ max_chunk_bytes,
+ max_chunks,
+ reader_rank: int = 0,
+ blocking: bool = False,
+ ) -> tuple["MessageQueue", list[Handle]]:
+ """
+ Creates a MessageQueue for a process group with a single reader.
+
+ This method is designed for scenarios where only one process (the reader)
+ will consume messages, and all other processes are writers. It sets up
+ the shared memory buffer and communication handles accordingly, and
+ gathers the handles from all processes to the reader.
+
+ Args:
+ pg (ProcessGroup): The torch distributed process group.
+ max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
+ max_chunks (int): Maximum number of chunks in the buffer.
+ reader_rank (int, optional): The global rank that will act as the reader.
+ Defaults to 0.
+ blocking (bool, optional): If True, blocks until all processes are ready.
+ Defaults to False.
+
+ Returns:
+ tuple[MessageQueue, list[Handle]]:
+ The MessageQueue instance for the calling process,
+ and a list of handles (only non-empty for the reader process).
+ """
+ local_size = torch.cuda.device_count()
+ rank = dist.get_rank()
+ same_node = rank // local_size == reader_rank // local_size
+ buffer_io = MessageQueue(
+ n_reader=1,
+ n_local_reader=1 if same_node else 0,
+ max_chunk_bytes=max_chunk_bytes,
+ max_chunks=max_chunks,
+ )
+ handle = buffer_io.export_handle()
+ handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
+ dist.gather_object(handle, handles, dst=reader_rank, group=pg)
+ if blocking:
+ buffer_io.wait_until_ready()
+ return buffer_io, cast(list[Handle], handles or [])
+
@staticmethod
def create_from_process_group(
pg: ProcessGroup | StatelessProcessGroup,
max_chunk_bytes,
max_chunks,
- writer_rank=0,
+ writer_rank: int = 0,
+ external_writer_handle=None,
+ blocking: bool = True,
) -> "MessageQueue":
+ """
+ Creates a MessageQueue for a distributed process group with one writer and
+ multiple readers.
+
+ This method is designed for scenarios where one process (the writer) sends
+ messages, and all other processes (the readers) receive messages. It sets up
+ the shared memory buffer and socket communication handles accordingly, and
+ broadcasts the handle from the writer to all readers.
+
+ Args:
+ pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
+ group.
+ max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
+ max_chunks (int): Maximum number of chunks in the buffer.
+ writer_rank (int, optional): The global rank that will act as the writer.
+ Defaults to 0.
+ external_writer_handle (Handle, optional): Used when there is a handle
+ from an external Message Queue. If provided, use this handle to init
+ PG writer message queue instead of creating a new one. Defaults to None.
+ blocking (bool, optional): If True, blocks until all processes are ready.
+ Defaults to True.
+
+ Returns:
+ MessageQueue: The MessageQueue instance for the calling process.
+
+ """
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
@@ -617,23 +691,26 @@ class MessageQueue:
group_rank = pg.rank
group_world_size = pg.world_size
global_ranks = list(range(pg.world_size))
-
from vllm.distributed.parallel_state import in_the_same_node_as
status = in_the_same_node_as(pg, source_rank=writer_rank)
- same_node_ranks = [i for i, s in enumerate(status) if s]
- n_reader = group_world_size - 1
- n_local_reader = len(same_node_ranks) - 1
- local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
- buffer_io: MessageQueue
if group_rank == writer_rank:
- buffer_io = MessageQueue(
- n_reader=n_reader,
- n_local_reader=n_local_reader,
- local_reader_ranks=local_reader_ranks,
- max_chunk_bytes=max_chunk_bytes,
- max_chunks=max_chunks,
- )
+ if external_writer_handle is not None:
+ buffer_io = MessageQueue.create_from_handle(
+ external_writer_handle, group_rank
+ )
+ else:
+ same_node_ranks = [i for i, s in enumerate(status) if s]
+ n_reader = group_world_size - 1
+ n_local_reader = len(same_node_ranks) - 1
+ local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+ buffer_io = MessageQueue(
+ n_reader=n_reader,
+ n_local_reader=n_local_reader,
+ local_reader_ranks=local_reader_ranks,
+ max_chunk_bytes=max_chunk_bytes,
+ max_chunks=max_chunks,
+ )
handle = buffer_io.export_handle()
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list(
@@ -651,5 +728,6 @@ class MessageQueue:
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
- buffer_io.wait_until_ready()
+ if blocking:
+ buffer_io.wait_until_ready()
return buffer_io
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 3d4547c514532..1626f819af8b5 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -108,6 +108,7 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str
+ block_size: int
@dataclass
@@ -676,12 +677,13 @@ class NixlConnectorWorker:
mapping between local and remote TP workers.
"""
- tp_size: int
tp_rank: int
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
+ engine_id: EngineId
+ remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
@@ -709,6 +711,14 @@ class NixlConnectorWorker:
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
)
+ @property
+ def tp_size(self) -> int:
+ return self.remote_tp_size[self.engine_id]
+
+ @property
+ def block_size(self) -> int:
+ return self.remote_block_size[self.engine_id]
+
def tp_ratio(
self,
remote_tp_size: int,
@@ -725,6 +735,19 @@ class NixlConnectorWorker:
)
return self.tp_size // remote_tp_size
+ def block_size_ratio(
+ self,
+ remote_block_size: int,
+    ) -> int:
+ """
+        Calculate the ratio between the local and remote KV cache block sizes.
+ """
+ assert self.block_size % remote_block_size == 0, (
+ f"Local block size {self.block_size} is not divisible "
+            f"by remote block size {remote_block_size}."
+ )
+ return self.block_size // remote_block_size
+
def tp_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
@@ -732,6 +755,13 @@ class NixlConnectorWorker:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
+ def block_size_ratio_from_engine_id(
+ self,
+ remote_engine_id: EngineId,
+    ) -> int:
+ remote_block_size = self.remote_block_size[remote_engine_id]
+ return self.block_size_ratio(remote_block_size)
+
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
@@ -866,6 +896,7 @@ class NixlConnectorWorker:
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
+ self.src_xfer_side_handles: dict[int, int] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[EngineId, int] = {}
@@ -925,15 +956,17 @@ class NixlConnectorWorker:
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
+ self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
self.kv_topo = self.TpKVTopology(
- tp_size=self.world_size,
tp_rank=self.tp_rank,
+ engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
+ remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
@@ -987,9 +1020,13 @@ class NixlConnectorWorker:
)
# Register Remote agent.
+ assert metadata.block_size <= self.block_size, (
+ "nP > nD is not supported yet."
+ )
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
+
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
@@ -1124,6 +1161,14 @@ class NixlConnectorWorker:
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
+
+ # TODO (NickLucche): Get kernel_block_size in a cleaner way
+ # NHD default "view" for non-MLA cache
+ if self.device_type == "cpu":
+ block_size_position = -2
+ else:
+ block_size_position = -2 if self.use_mla else -3
+
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
@@ -1138,9 +1183,7 @@ class NixlConnectorWorker:
if base_addr in seen_base_addresses:
continue
- # TODO (NickLucche): Get kernel_block_size in a cleaner way
- # NHD default "view" for non-MLA cache
- kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
+ kernel_block_size = cache.shape[block_size_position]
if self.block_size != kernel_block_size:
logger.info_once(
@@ -1153,6 +1196,7 @@ class NixlConnectorWorker:
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
+ self._block_size[self.engine_id] = kernel_block_size
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
@@ -1217,43 +1261,10 @@ class NixlConnectorWorker:
self.num_regions *= 2
# Register local/src descr for NIXL xfer.
- blocks_data = []
- for i, base_addr in enumerate(seen_base_addresses):
- kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
- # NOTE With heter-TP, more blocks are prepared than what are
- # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
- # could create fewer, but then _get_block_descs_ids needs to
- # select agent_meta.num_blocks instead of self.num_blocks for
- # local descr, and that makes handling regular flow less clean.
- for block_id in range(self.num_blocks):
- block_offset = block_id * self.block_len_per_layer[i]
- addr = base_addr + block_offset
- # (addr, len, device id)
- blocks_data.append((addr, kv_block_len, self.device_id))
+ self.seen_base_addresses = seen_base_addresses
+ self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size)
- if self.kv_topo.is_kv_layout_blocks_first:
- # Separate and interleave K/V regions to maintain the same
- # descs ordering. This is needed for selecting contiguous heads
- # when split across TP ranks.
- for block_id in range(self.num_blocks):
- block_offset = block_id * self.block_len_per_layer[i]
- addr = base_addr + block_offset
- # Register addresses for V cache (K registered first).
- v_addr = addr + kv_block_len
- blocks_data.append((v_addr, kv_block_len, self.device_id))
- logger.debug(
- "Created %s blocks for src engine %s and rank %s on device id %s",
- len(blocks_data),
- self.engine_id,
- self.tp_rank,
- self.device_id,
- )
-
- descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
- # NIXL_INIT_AGENT to be used for preparations of local descs.
- self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
- "NIXL_INIT_AGENT", descs
- )
+ self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
@@ -1289,8 +1300,62 @@ class NixlConnectorWorker:
kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
+ block_size=self.block_size,
)
+ def register_local_xfer_handler(
+ self,
+ block_size: int,
+ ) -> int:
+ """
+ Function used for register local xfer handler with local block_size or
+ Remote block_size.
+
+ When local block_size is same as remote block_size, we use local block_size
+ to register local_xfer_handler during init.
+
+ When remote block size is less than local block size, we need to use
+ register another local_xfer_handler using remote block len to ensure
+ data copy correctness.
+ """
+ block_size_ratio = self.block_size // block_size
+ blocks_data = []
+ for i, base_addr in enumerate(self.seen_base_addresses):
+ # The new block_len is using prefill block_len;
+ # and num_blocks is multiple with N
+ kv_block_len = (
+ self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
+ )
+ block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
+ num_blocks = self.num_blocks * block_size_ratio
+ for block_id in range(num_blocks):
+ block_offset = block_id * block_len_per_layer
+ addr = base_addr + block_offset
+ # (addr, len, device id)
+ blocks_data.append((addr, kv_block_len, self.device_id))
+
+ if self.kv_topo.is_kv_layout_blocks_first:
+ # Separate and interleave K/V regions to maintain the same
+ # descs ordering. This is needed for selecting contiguous heads
+ # when split across TP ranks.
+ for block_id in range(num_blocks):
+ block_offset = block_id * block_len_per_layer
+ addr = base_addr + block_offset
+ # Register addresses for V cache (K registered first).
+ v_addr = addr + kv_block_len
+ blocks_data.append((v_addr, kv_block_len, self.device_id))
+ logger.debug(
+ "Created %s blocks for src engine %s and rank %s on device id %s",
+ len(blocks_data),
+ self.engine_id,
+ self.tp_rank,
+ self.device_id,
+ )
+
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+ # NIXL_INIT_AGENT to be used for preparations of local descs.
+ return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
+
def add_remote_agent(
self,
nixl_agent_meta: NixlAgentMetadata,
@@ -1349,6 +1414,8 @@ class NixlConnectorWorker:
### Register remote agent metadata
if engine_id not in self._tp_size:
self._tp_size[engine_id] = remote_tp_size
+ if engine_id not in self._block_size:
+ self._block_size[engine_id] = nixl_agent_meta.block_size
remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata
@@ -1359,6 +1426,13 @@ class NixlConnectorWorker:
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
+ # Example:
+ # block_size_ratio > 1:
+ # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
+ # local origin:| 0| 1| 8| 12|
+ # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
+ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
+
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
@@ -1381,8 +1455,14 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+ remote_kv_block_len = kv_block_len // block_size_ratio
+ if block_size_ratio > 1:
+ # using remote kv_block_len as transfer unit
+ kv_block_len = remote_kv_block_len
rank_offset = (
- self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
+ self.tp_rank % tp_ratio * remote_kv_block_len
+ if not replicates_kv_cache
+ else 0
)
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1417,6 +1497,13 @@ class NixlConnectorWorker:
remote_agent_name, descs
)
+ if block_size_ratio > 1:
+ # when prefill uses a smaller block_size, we need to init a
+ # new handler with the same block_len to match
+ self.src_xfer_side_handles[nixl_agent_meta.block_size] = (
+ self.register_local_xfer_handler(nixl_agent_meta.block_size)
+ )
+
return remote_agent_name
def _validate_remote_agent_handshake(
@@ -1433,6 +1520,9 @@ class NixlConnectorWorker:
assert nixl_agent_meta.attn_backend_name == self.backend_name
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
+ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
+ remote_engine_id
+ )
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
@@ -1463,33 +1553,26 @@ class NixlConnectorWorker:
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
# With replicated KV cache, only the number of blocks can differ.
- assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
- "KV cache sizes must match between P and D when replicated"
- )
- remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
+ for i in range(len(self.block_len_per_layer)):
+ assert (
+ self.block_len_per_layer[i] // block_size_ratio
+ == nixl_agent_meta.block_lens[i]
+ ), "KV cache sizes must match between P and D when replicated"
else:
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, (
"All remote layers must have the same block size"
)
- remote_block_size = remote_block_len // (
- self.slot_size_per_layer[0] * tp_ratio
- )
- if self.kv_topo.is_kv_layout_blocks_first:
- # With flashinfer, KV are sent in the same message.
- remote_block_size //= 2
- assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
+ assert (
+ remote_block_len
+ == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
+ ), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
- assert self.block_size == remote_block_size, (
- "Remote P worker with different page/block size is not supported "
- f"{self.block_size=}, {remote_block_size=}"
- )
-
# TP workers have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
@@ -1576,6 +1659,56 @@ class NixlConnectorWorker:
)
cache.index_copy_(0, indices, permuted_blocks)
+ def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]):
+ def _process_local_gt_remote(blocks_to_update, block_size_ratio):
+ n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
+ remote_block_size = block_size // block_size_ratio
+ n_blocks = block_size_ratio
+ # actual permute is to convert
+ # for local blocksize > remote blocksize
+ # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens
+ # local block[0] = remote block[0, 1, 2, 3]
+ # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
+ # local is |h0-b0..................|h1-b0..................|...
+ # permute is to:
+ # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
+ # 2. permute => (H, nblocks, remoteN, D)
+ # 3. flatten => (H, localN, D)
+ permuted_blocks = (
+ blocks_to_update.reshape(
+ -1, n_blocks, n_kv_heads, remote_block_size, head_size
+ )
+ .permute(0, 2, 1, 3, 4)
+ .flatten(2, 3)
+ )
+ return permuted_blocks
+
+ if len(self.device_kv_caches) == 0:
+ return
+ split_k_and_v = not (
+ self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
+ )
+ sample_cache = list(self.device_kv_caches.values())[0][0]
+ for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
+ assert block_size_ratio > 1, "Only nP < nD supported currently."
+ block_ids_list = [[item for sublist in block_ids_list for item in sublist]]
+
+ for block_ids in block_ids_list:
+ indices = torch.tensor(block_ids, device=sample_cache.device)
+
+ for _, cache_or_caches in self.device_kv_caches.items():
+ cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
+ for cache in cache_list:
+ blocks_to_update = cache.index_select(0, indices)
+ # because kv_cache is always using original layout NHD as
+ # virtual shape while stride can be either HND / NHD at
+ # initialization.
+ # we need to firstly get physical view of the tensor
+ permuted_blocks = _process_local_gt_remote(
+ blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
+ ).permute(0, 2, 1, 3)
+ cache.index_copy_(0, indices, permuted_blocks)
+
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving on this specific worker.
@@ -1599,6 +1732,7 @@ class NixlConnectorWorker:
)
block_ids_to_permute = []
+ block_ids_for_blocksize_post_process = defaultdict(list)
for req_id in done_recving:
# clean up metadata for completed requests
meta = self._recving_metadata.pop(req_id, None)
@@ -1607,6 +1741,20 @@ class NixlConnectorWorker:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_physical_block_ids
+
+ # post processing for heteroblocksize
+ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
+ meta.remote_engine_id
+ )
+ if (
+ not self.use_mla
+ and block_size_ratio > 1
+ and self.kv_cache_layout == "HND"
+ ):
+ block_ids_for_blocksize_post_process[block_size_ratio].append(
+ meta.local_block_ids
+ )
+ self.blocksize_post_process(block_ids_for_blocksize_post_process)
if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute)
@@ -1781,6 +1929,24 @@ class NixlConnectorWorker:
dst_engine_id: str,
request_id: str,
):
+ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
+ if block_size_ratio > 1:
+ local_block_ids = self.get_mapped_blocks(
+ np.asarray(local_block_ids), block_size_ratio
+ )
+ if len(local_block_ids) > len(remote_block_ids):
+ # NOTE:
+ # get_mapped_blocks will always expand block_ids for n times.
+ # ex:
+ # prefill block_ids with block_size as 4:
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ # Local decode block_ids with block_size as 16: [1, 2, 3]
+ # expand decode block_ids with get_mapped_blocks from [1, 2, 3] to
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+ # Then we clip local to align with prefill
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ local_block_ids = local_block_ids[: len(remote_block_ids)]
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
@@ -1823,7 +1989,10 @@ class NixlConnectorWorker:
remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
- local_xfer_side_handle = self.src_xfer_side_handle
+ remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
+ local_xfer_side_handle = self.src_xfer_side_handles.get(
+ remote_block_size, self.src_xfer_side_handle
+ )
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
@@ -1833,13 +2002,17 @@ class NixlConnectorWorker:
# Get descs ids.
local_block_descs_ids: np.ndarray
remote_block_descs_ids: np.ndarray
+
if not self.block_window_per_layer:
# Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids(
- dst_engine_id, remote_block_ids
+ dst_engine_id,
+ remote_block_ids,
)
local_block_descs_ids = self._get_block_descs_ids(
- self.engine_id, local_block_ids
+ self.engine_id,
+ local_block_ids,
+ block_size_ratio=block_size_ratio,
)
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
@@ -1860,10 +2033,15 @@ class NixlConnectorWorker:
# Get descs ids for the layer.
layer_local_desc_ids = self._get_block_descs_ids(
- self.engine_id, layer_local_block_ids, layer_idx
+ self.engine_id,
+ layer_local_block_ids,
+ layer_idx,
+ block_size_ratio=block_size_ratio,
)
layer_remote_desc_ids = self._get_block_descs_ids(
- dst_engine_id, layer_remote_block_ids, layer_idx
+ dst_engine_id,
+ layer_remote_block_ids,
+ layer_idx,
)
local_descs_list.append(layer_local_desc_ids)
@@ -1905,8 +2083,31 @@ class NixlConnectorWorker:
self.nixl_wrapper.release_xfer_handle(handle)
self._failed_recv_reqs.add(request_id)
+ def get_mapped_blocks(self, block_ids, block_size_ratio):
+ """
+ Calculates the new set of block IDs by mapping every element
+ in the (potentially sparse) input array.
+ Example: block_ids=[0, 2], block_size_ratio=2
+ get_mapped_blocks 0 1 [2 3] 4 5
+ # remote is |h0-b0|h1-b0||h0-b1|h1-b1||h0-b2|h1-b2||
+ # local is |h0-b0......||h1-b0......||h2-b0........
+ local_block_ids 0 [1] 2
+ """
+ if block_ids.size == 0:
+ return np.array([], dtype=np.int64)
+
+ start_ids = block_ids * block_size_ratio
+ offsets = np.arange(block_size_ratio)
+ mapped_2d = start_ids[:, None] + offsets[None, :]
+
+ return mapped_2d.flatten().astype(np.int64)
+
def _get_block_descs_ids(
- self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
+ self,
+ engine_id: str,
+ block_ids: list[int],
+ layer_idx: int | None = None,
+ block_size_ratio: float | None = None,
) -> np.ndarray:
"""
Get the descs ids for a set of block ids.
@@ -1929,6 +2130,8 @@ class NixlConnectorWorker:
region_ids = np.arange(layer_idx, layer_idx + 1)
num_blocks = self.dst_num_blocks[engine_id]
+ if block_size_ratio is not None:
+ num_blocks = int(num_blocks * block_size_ratio)
# Compute the desc ids for each block.
region_ids = region_ids[:, None]
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index c78e6a32733c1..852c4c644433f 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -385,6 +385,33 @@ class GroupCoordinator:
torch.ops._C, "init_shm_manager"
)
+ def create_mq_broadcaster(
+ self, writer_rank=0, external_writer_handle=None, blocking=True
+ ):
+ from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
+
+ return MessageQueue.create_from_process_group(
+ self.cpu_group,
+ 1 << 22,
+ 6,
+ writer_rank=writer_rank,
+ external_writer_handle=external_writer_handle,
+ blocking=blocking,
+ )
+
+ def create_single_reader_mq_broadcasters(
+ self, reader_rank_in_group=0, blocking=False
+ ):
+ from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
+
+ return MessageQueue.create_from_process_group_single_reader(
+ self.cpu_group,
+ 1 << 22,
+ 6,
+ reader_rank=self.ranks[reader_rank_in_group],
+ blocking=blocking,
+ )
+
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@@ -997,6 +1024,7 @@ class GroupCoordinator:
_WORLD: GroupCoordinator | None = None
+_INNER_DP_WORLD: GroupCoordinator | None = None
_NODE_COUNT: int | None = None
@@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator:
return _WORLD
+def get_inner_dp_world_group() -> GroupCoordinator:
+ assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized"
+ return _INNER_DP_WORLD
+
+
def init_world_group(
ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
@@ -1023,12 +1056,13 @@ def init_model_parallel_group(
backend: str,
use_message_queue_broadcaster: bool = False,
group_name: str | None = None,
+ use_device_communicator: bool = True,
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
- use_device_communicator=True,
+ use_device_communicator=use_device_communicator,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
@@ -1143,7 +1177,14 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
- if (
+ if config is not None and config.parallel_config.nnodes > 1:
+ parallel_config = config.parallel_config
+ ip = parallel_config.master_addr
+ rank = parallel_config.data_parallel_rank * world_size + rank
+ world_size = parallel_config.world_size_across_dp
+ port = parallel_config.master_port
+ distributed_init_method = get_distributed_init_method(ip, port)
+ elif (
config is not None
and config.parallel_config.data_parallel_size > 1
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1164,6 +1205,14 @@ def init_distributed_environment(
distributed_init_method,
)
if not torch.distributed.is_initialized():
+ logger.info(
+ "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
+ world_size,
+ rank,
+ local_rank,
+ distributed_init_method,
+ backend,
+ )
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment"
@@ -1192,16 +1241,36 @@ def init_distributed_environment(
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
- global _WORLD, _NODE_COUNT
+ global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
- _NODE_COUNT = _node_count(_WORLD.cpu_group)
+ if config is not None and config.parallel_config.nnodes > 1:
+ _NODE_COUNT = config.parallel_config.nnodes
+ else:
+ _NODE_COUNT = _node_count(_WORLD.cpu_group)
logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size"
)
+ if config is not None and config.parallel_config.nnodes_within_dp > 1:
+ parallel_config = config.parallel_config
+ if parallel_config.data_parallel_size > 1:
+ world_size_inner_dp = parallel_config.world_size
+ group_ranks = [
+ [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
+ for dp_rank in range(parallel_config.data_parallel_size)
+ ]
+ _INNER_DP_WORLD = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ use_message_queue_broadcaster=True,
+ group_name="inner_dp_world",
+ use_device_communicator=False,
+ )
+ else:
+ _INNER_DP_WORLD = _WORLD
def initialize_model_parallel(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 999ed780c20bf..ab6e5e594c239 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -384,6 +384,10 @@ class EngineArgs:
) = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
+ master_addr: str = ParallelConfig.master_addr
+ master_port: int = ParallelConfig.master_port
+ nnodes: int = ParallelConfig.nnodes
+ node_rank: int = ParallelConfig.node_rank
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
@@ -394,6 +398,7 @@ class EngineArgs:
data_parallel_address: str | None = None
data_parallel_rpc_port: int | None = None
data_parallel_hybrid_lb: bool = False
+ data_parallel_external_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
all2all_backend: str | None = ParallelConfig.all2all_backend
@@ -749,6 +754,10 @@ class EngineArgs:
"-pp",
**parallel_kwargs["pipeline_parallel_size"],
)
+ parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
+ parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
+ parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
+ parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
parallel_group.add_argument(
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
)
@@ -803,7 +812,14 @@ class EngineArgs:
help='Backend for data parallel, either "mp" or "ray".',
)
parallel_group.add_argument(
- "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
+ "--data-parallel-hybrid-lb",
+ "-dph",
+ **parallel_kwargs["data_parallel_hybrid_lb"],
+ )
+ parallel_group.add_argument(
+ "--data-parallel-external-lb",
+ "-dpe",
+ **parallel_kwargs["data_parallel_external_lb"],
)
parallel_group.add_argument(
"--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
@@ -1030,10 +1046,18 @@ class EngineArgs:
description=SchedulerConfig.__doc__,
)
scheduler_group.add_argument(
- "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
+ "--max-num-batched-tokens",
+ **{
+ **scheduler_kwargs["max_num_batched_tokens"],
+ "default": None,
+ },
)
scheduler_group.add_argument(
- "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
+ "--max-num-seqs",
+ **{
+ **scheduler_kwargs["max_num_seqs"],
+ "default": None,
+ },
)
scheduler_group.add_argument(
"--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
@@ -1055,7 +1079,11 @@ class EngineArgs:
"--scheduling-policy", **scheduler_kwargs["policy"]
)
scheduler_group.add_argument(
- "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
+ "--enable-chunked-prefill",
+ **{
+ **scheduler_kwargs["enable_chunked_prefill"],
+ "default": None,
+ },
)
scheduler_group.add_argument(
"--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
@@ -1428,12 +1456,56 @@ class EngineArgs:
assert not headless or not self.data_parallel_hybrid_lb, (
"data_parallel_hybrid_lb is not applicable in headless mode"
)
-
- data_parallel_external_lb = self.data_parallel_rank is not None
+ assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
+ "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
+ )
+ assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
+ "nnodes > 1 is only supported with data_parallel_backend=mp"
+ )
+ inferred_data_parallel_rank = 0
+ if self.nnodes > 1:
+ world_size = (
+ self.data_parallel_size
+ * self.pipeline_parallel_size
+ * self.tensor_parallel_size
+ )
+ world_size_within_dp = (
+ self.pipeline_parallel_size * self.tensor_parallel_size
+ )
+ local_world_size = world_size // self.nnodes
+ assert world_size % self.nnodes == 0, (
+ f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
+ )
+ assert self.node_rank < self.nnodes, (
+ f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
+ )
+ inferred_data_parallel_rank = (
+ self.node_rank * local_world_size
+ ) // world_size_within_dp
+ if self.data_parallel_size > 1 and self.data_parallel_external_lb:
+ self.data_parallel_rank = inferred_data_parallel_rank
+ logger.info(
+ "Inferred data_parallel_rank %d from node_rank %d for external lb",
+ self.data_parallel_rank,
+ self.node_rank,
+ )
+ elif self.data_parallel_size_local is None:
+ # Infer data parallel size local for internal dplb:
+ self.data_parallel_size_local = max(
+ local_world_size // world_size_within_dp, 1
+ )
+ data_parallel_external_lb = (
+ self.data_parallel_external_lb or self.data_parallel_rank is not None
+ )
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
+ assert self.data_parallel_rank is not None, (
+ "data_parallel_rank or node_rank must be spefified if "
+ "data_parallel_external_lb is enable."
+ )
assert self.data_parallel_size_local in (1, None), (
- "data_parallel_size_local must be 1 when data_parallel_rank is set"
+ "data_parallel_size_local must be 1 or None when data_parallel_rank "
+ "is set"
)
data_parallel_size_local = 1
# Use full external lb if we have local_size of 1.
@@ -1447,6 +1519,11 @@ class EngineArgs:
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
# Use full external lb if we have local_size of 1.
+ logger.warning(
+ "data_parallel_hybrid_lb is not applicable when "
+ "data_parallel_size_local = 1; automatically switching to "
+ "data_parallel_external_lb."
+ )
data_parallel_external_lb = True
self.data_parallel_hybrid_lb = False
@@ -1454,7 +1531,15 @@ class EngineArgs:
# Disable hybrid LB mode if set for a single node
self.data_parallel_hybrid_lb = False
- self.data_parallel_rank = self.data_parallel_start_rank or 0
+ self.data_parallel_rank = (
+ self.data_parallel_start_rank or inferred_data_parallel_rank
+ )
+ if self.nnodes > 1:
+ logger.info(
+ "Inferred data_parallel_rank %d from node_rank %d",
+ self.data_parallel_rank,
+ self.node_rank,
+ )
else:
assert not self.data_parallel_hybrid_lb, (
"data_parallel_size_local must be set to use data_parallel_hybrid_lb."
@@ -1484,7 +1569,9 @@ class EngineArgs:
"data_parallel_backend can only be ray or mp, got %s",
self.data_parallel_backend,
)
- data_parallel_address = ParallelConfig.data_parallel_master_ip
+ data_parallel_address = (
+ self.master_addr or ParallelConfig.data_parallel_master_ip
+ )
else:
data_parallel_address = self.data_parallel_address
@@ -1517,6 +1604,10 @@ class EngineArgs:
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
data_parallel_size_local=data_parallel_size_local,
+ master_addr=self.master_addr,
+ master_port=self.master_port,
+ nnodes=self.nnodes,
+ node_rank=self.node_rank,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 24fcd9fe1cab9..462d2c4e50e73 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -125,7 +125,7 @@ class EngineClient(ABC):
...
@abstractmethod
- async def reset_prefix_cache(self, device: Device | None = None) -> None:
+ async def reset_prefix_cache(self) -> None:
"""Reset the prefix cache"""
...
diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index 2678658dd1262..96608f360e17b 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -24,6 +24,7 @@ from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor import Executor
+from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
@@ -97,18 +98,40 @@ def run_headless(args: argparse.Namespace):
if local_engine_count <= 0:
raise ValueError("data_parallel_size_local must be > 0 in headless mode")
- host = parallel_config.data_parallel_master_ip
- port = engine_args.data_parallel_rpc_port # add to config too
- handshake_address = get_tcp_uri(host, port)
+ shutdown_requested = False
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
+ nonlocal shutdown_requested
logger.debug("Received %d signal.", signum)
- raise SystemExit
+ if not shutdown_requested:
+ shutdown_requested = True
+ raise SystemExit
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
+ if parallel_config.node_rank_within_dp > 0:
+ from vllm.version import __version__ as VLLM_VERSION
+
+ # Run headless workers (for multi-node PP/TP).
+ host = parallel_config.master_addr
+ head_node_address = f"{host}:{parallel_config.master_port}"
+ logger.info(
+ "Launching vLLM (v%s) headless multiproc executor, "
+ "with head node address %s for torch.distributed process group.",
+ VLLM_VERSION,
+ head_node_address,
+ )
+
+ executor = MultiprocExecutor(vllm_config, monitor_workers=False)
+ executor.start_worker_monitor(inline=True)
+ return
+
+ host = parallel_config.data_parallel_master_ip
+ port = parallel_config.data_parallel_rpc_port
+ handshake_address = get_tcp_uri(host, port)
+
logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.",
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 62717a7eacdf0..b0786bd355aa6 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -32,7 +32,6 @@ from vllm.config.model import (
TokenizerMode,
)
from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.protocol import Device
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
@@ -1499,8 +1498,8 @@ class LLM:
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
- def reset_prefix_cache(self, device: Device | None = None) -> None:
- self.llm_engine.reset_prefix_cache(device)
+ def reset_prefix_cache(self) -> None:
+ self.llm_engine.reset_prefix_cache()
def sleep(self, level: int = 1):
"""
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 3e59af717d95c..3974f45a7135c 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -5,6 +5,7 @@ import hashlib
import importlib
import inspect
import json
+import logging
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
@@ -39,7 +40,7 @@ from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.protocol import Device, EngineClient
+from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicError,
AnthropicErrorResponse,
@@ -1069,12 +1070,8 @@ if envs.VLLM_SERVER_DEV_MODE:
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
"""
- device = None
- device_str = raw_request.query_params.get("device")
- if device_str is not None:
- device = Device[device_str.upper()]
- logger.info("Resetting prefix cache with specific %s...", str(device))
- await engine_client(raw_request).reset_prefix_cache(device)
+ logger.info("Resetting prefix cache...")
+ await engine_client(raw_request).reset_prefix_cache()
return Response(status_code=200)
@router.post("/reset_mm_cache")
@@ -2024,6 +2021,9 @@ async def run_server(args, **uvicorn_kwargs) -> None:
# Add process-specific prefix to stdout and stderr.
decorate_logs("APIServer")
+ # Suppress verbose logs from model_hosting_container_standards
+ logging.getLogger("model_hosting_container_standards").setLevel(logging.ERROR)
+
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
index 0453db58361a9..2b84c60a3b841 100644
--- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
@@ -34,14 +34,34 @@ class KimiK2ToolParser(ToolParser):
str
] = [] # map what has been streamed for each tool so far to a list
+ # Section-level state management to prevent token leakage
+ self.in_tool_section: bool = False
+ self.token_buffer: str = ""
+ # Buffer size: empirical worst-case for longest marker (~30 chars) * 2
+ # + safety margin for unicode + partial overlap. Prevents unbounded growth.
+ self.buffer_max_size: int = 1024
+ self.section_char_count: int = 0 # Track characters processed in tool section
+ self.max_section_chars: int = 8192 # Force exit if section exceeds this
+ self._buffer_overflow_logged: bool = False # Log overflow once per session
+
+ # Support both singular and plural variants
self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
+ self.tool_calls_start_token_variants: list[str] = [
+ "<|tool_calls_section_begin|>",
+ "<|tool_call_section_begin|>", # singular variant
+ ]
+ self.tool_calls_end_token_variants: list[str] = [
+ "<|tool_calls_section_end|>",
+ "<|tool_call_section_end|>", # singular variant
+ ]
self.tool_call_start_token: str = "<|tool_call_begin|>"
self.tool_call_end_token: str = "<|tool_call_end|>"
self.tool_call_regex = re.compile(
- r"<\|tool_call_begin\|>\s*(?P.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P.*?)\s*<\|tool_call_end\|>"
+ r"<\|tool_call_begin\|>\s*(?P[^<]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P(?:(?!<\|tool_call_begin\|>).)*?)\s*<\|tool_call_end\|>",
+ re.DOTALL,
)
self.stream_tool_call_portion_regex = re.compile(
@@ -58,6 +78,18 @@ class KimiK2ToolParser(ToolParser):
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
+ # Get token IDs for all variants
+ self.tool_calls_start_token_ids: list[int] = [
+ tid
+ for variant in self.tool_calls_start_token_variants
+ if (tid := self.vocab.get(variant)) is not None
+ ]
+ self.tool_calls_end_token_ids: list[int] = [
+ tid
+ for variant in self.tool_calls_end_token_variants
+ if (tid := self.vocab.get(variant)) is not None
+ ]
+
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
@@ -70,6 +102,51 @@ class KimiK2ToolParser(ToolParser):
"tokens in the tokenizer!"
)
+ def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]:
+ """
+ Check for section begin/end markers in text and strip them.
+ Returns: (cleaned_text, found_section_begin, found_section_end)
+ """
+ found_begin = False
+ found_end = False
+ cleaned = text
+
+ # Check for section begin markers (any variant)
+ for variant in self.tool_calls_start_token_variants:
+ if variant in cleaned:
+ cleaned = cleaned.replace(variant, "")
+ found_begin = True
+
+ # Check for section end markers (any variant)
+ for variant in self.tool_calls_end_token_variants:
+ if variant in cleaned:
+ cleaned = cleaned.replace(variant, "")
+ found_end = True
+
+ return cleaned, found_begin, found_end
+
+ def _reset_section_state(self) -> None:
+ """Reset state when exiting tool section."""
+ self.in_tool_section = False
+ self.token_buffer = ""
+ self.section_char_count = 0
+
+ def reset_streaming_state(self) -> None:
+ """
+ Reset all streaming state. Call this between requests to prevent
+ state leakage when parser instance is reused.
+ """
+ # Reset section state
+ self._reset_section_state()
+
+ # Reset parent class state
+ self.current_tool_name_sent = False
+ self.prev_tool_call_arr = []
+ self.current_tool_id = -1
+ self.streamed_args_for_tool = []
+
+ logger.debug("Streaming state reset")
+
def extract_tool_calls(
self,
model_output: str,
@@ -131,13 +208,94 @@ class KimiK2ToolParser(ToolParser):
) -> DeltaMessage | None:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
- # check to see if we should be streaming a tool call - is there a
- if self.tool_calls_start_token_id not in current_token_ids:
- logger.debug("No tool call tokens found!")
- return DeltaMessage(content=delta_text)
- delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(
- self.tool_calls_end_token, ""
+
+ # Flag to defer section exit until after tool parsing completes
+ deferred_section_exit = False
+
+ # Add delta to buffer for split marker detection
+ self.token_buffer += delta_text
+
+ # Enforce buffer size limit to prevent memory issues
+ if len(self.token_buffer) > self.buffer_max_size:
+ if not self._buffer_overflow_logged:
+ logger.warning(
+ "Token buffer exceeded max size (%d bytes), flushing excess. "
+ "This may indicate very long markers or unusual tokenization.",
+ self.buffer_max_size,
+ )
+ self._buffer_overflow_logged = True
+ # Keep only the most recent content that might contain partial markers
+ self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :]
+
+ # Check buffer for section markers (handles split tokens)
+ buffered_text, found_section_begin, found_section_end = (
+ self._check_and_strip_markers(self.token_buffer)
)
+
+ # Track section state transitions
+ if found_section_begin and not self.in_tool_section:
+ logger.debug("Entering tool section")
+ self.in_tool_section = True
+ self.token_buffer = buffered_text # Use cleaned buffer
+ self.section_char_count = 0 # Reset counter for new section
+ if found_section_end and self.in_tool_section:
+ logger.debug("Detected section end marker")
+ # CRITICAL: Don't exit early if tool_call_end is in this chunk.
+ # Tool parser must emit final arguments/close first to avoid dropping
+ # the final tool update and leaking tokens into reasoning channel.
+ has_tool_end = self.tool_call_end_token_id in delta_token_ids
+ if has_tool_end:
+ # Defer exit until after tool parsing completes
+ deferred_section_exit = True
+ logger.debug("Deferring section exit: tool_call_end in same chunk")
+ self.token_buffer = buffered_text
+ else:
+ # No tool call ending, safe to exit immediately
+ logger.debug("Exiting tool section")
+ remaining = buffered_text
+ self._reset_section_state()
+ # Return remaining text as reasoning content if non-empty
+ if remaining.strip():
+ return DeltaMessage(content=remaining)
+ # Return empty delta to maintain function contract
+ # (always returns DeltaMessage)
+ return DeltaMessage(content="")
+ else:
+ self.token_buffer = buffered_text
+
+ # Check if any variant of section start token is in current_token_ids
+ has_section_token = any(
+ tid in current_token_ids for tid in self.tool_calls_start_token_ids
+ )
+
+ # Early return: if no section token detected yet, return as reasoning content
+ if not has_section_token and not self.in_tool_section:
+ logger.debug("No tool call tokens found!")
+ # Don't clear buffer - it needs to accumulate partial markers across deltas
+        # Buffer overflow is already protected by the buffer size-limit check above
+ return DeltaMessage(content=delta_text)
+
+ # Strip section markers from delta_text for subsequent processing
+ # NOTE: This preprocessing happens BEFORE the regex-based tool call
+ # parsing (from PR #24847) to ensure markers are removed cleanly
+ # before pattern matching. No double-stripping occurs because
+ # section markers and tool call markers are distinct.
+ delta_text, _, _ = self._check_and_strip_markers(delta_text)
+
+ # Error recovery: If in tool section for too long, force exit
+ if self.in_tool_section:
+ self.section_char_count += len(delta_text)
+ if self.section_char_count > self.max_section_chars:
+ logger.warning(
+ "Tool section exceeded max length (%d chars), forcing exit. "
+ "This may indicate malformed model output.",
+ self.max_section_chars,
+ )
+ self._reset_section_state()
+ # Deferred exit already handled by forced exit above
+ # Return remaining content as reasoning (or empty delta if no content)
+ return DeltaMessage(content=delta_text if delta_text.strip() else "")
+
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
@@ -158,6 +316,16 @@ class KimiK2ToolParser(ToolParser):
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
+ # CRITICAL FIX: Suppress content if in tool section but
+ # no tool calls started
+ if self.in_tool_section and cur_tool_start_count == 0:
+ logger.debug(
+ "In tool section but no tool calls started yet. "
+ "Suppressing: %s",
+ delta_text,
+ )
+ # Return empty delta to maintain iterator contract
+ return DeltaMessage(content="")
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
@@ -209,6 +377,9 @@ class KimiK2ToolParser(ToolParser):
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
+ # Handle deferred section exit before returning
+ if deferred_section_exit and self.in_tool_section:
+ self._reset_section_state()
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
@@ -218,6 +389,9 @@ class KimiK2ToolParser(ToolParser):
else diff
)
if '"}' not in delta_text:
+ # Handle deferred section exit before returning
+ if deferred_section_exit and self.in_tool_section:
+ self._reset_section_state()
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
@@ -227,6 +401,10 @@ class KimiK2ToolParser(ToolParser):
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
+ # Handle deferred section exit before returning
+ if deferred_section_exit and self.in_tool_section:
+ logger.debug("Completing deferred section exit")
+ self._reset_section_state()
return DeltaMessage(
tool_calls=[
DeltaToolCall(
@@ -240,9 +418,19 @@ class KimiK2ToolParser(ToolParser):
# case -- otherwise we're just generating text
else:
+ # Check if we're in tool section - if so, suppress
+ if self.in_tool_section:
+ logger.debug("In tool section, suppressing text generation")
+ # Handle deferred section exit before returning
+ if deferred_section_exit:
+ self._reset_section_state()
+ return DeltaMessage(content="")
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
+ # Handle deferred section exit before returning
+ if deferred_section_exit and self.in_tool_section:
+ self._reset_section_state()
return delta
current_tool_call = dict()
@@ -390,6 +578,11 @@ class KimiK2ToolParser(ToolParser):
else:
self.prev_tool_call_arr.append(current_tool_call)
+ # Handle deferred section exit after tool parsing completes
+ if deferred_section_exit and self.in_tool_section:
+ logger.debug("Completing deferred section exit")
+ self._reset_section_state()
+
return delta
except Exception:
diff --git a/vllm/envs.py b/vllm/envs.py
index 7987e5fb83fdf..6d92d5afee501 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -50,7 +50,6 @@ if TYPE_CHECKING:
VLLM_CPU_KVCACHE_SPACE: int | None = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
- VLLM_CPU_MOE_PREPACK: bool = True
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
@@ -225,7 +224,6 @@ if TYPE_CHECKING:
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
- VLLM_FLAT_LOGPROBS: bool = False
def get_default_cache_root():
@@ -423,7 +421,7 @@ def get_vllm_port() -> int | None:
raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err
-# The begin-* and end* here are used by the documentation generator
+# The start-* and end-* here are used by the documentation generator
# to extract the used env vars.
# --8<-- [start:env-vars-definition]
@@ -666,10 +664,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
)
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ
else None,
- # (CPU backend only) whether to use prepack for MoE layer. This will be
- # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
- # need to set this to "0" (False).
- "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# If the env var is set, Ray Compiled Graph uses the specified
@@ -1499,11 +1493,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
),
- # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
- # the original list[dict[int, Logprob]] approach.
- # After enabled, PromptLogprobs and SampleLogprobs would populated as
- # FlatLogprobs.
- "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
}
# --8<-- [end:env-vars-definition]
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 44bc2a4cda311..25fb7181a8f29 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -221,6 +221,10 @@ def get_forward_context() -> ForwardContext:
return _forward_context
+def is_forward_context_available() -> bool:
+ return _forward_context is not None
+
+
def create_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 80d5322a34c3a..839c13868a16c 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -348,18 +348,15 @@ class InputPreprocessor:
)
inputs: TokenInputs | MultiModalInputs
- if self.model_config.is_multimodal_model:
+ if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
- parsed_content.get("multi_modal_data") or {},
+ multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
else:
- if parsed_content.get("multi_modal_data"):
- raise ValueError("This model does not support multimodal inputs")
-
inputs = token_inputs(prompt_token_ids)
if cache_salt := parsed_content.get("cache_salt"):
@@ -377,18 +374,15 @@ class InputPreprocessor:
prompt_text = parsed_content["prompt"]
inputs: TokenInputs | MultiModalInputs
- if self.model_config.is_multimodal_model:
+ if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_text,
- parsed_content.get("multi_modal_data") or {},
+ multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
else:
- if parsed_content.get("multi_modal_data"):
- raise ValueError("This model does not support multimodal inputs")
-
prompt_token_ids = self._tokenize_prompt(
prompt_text,
tokenization_kwargs=tokenization_kwargs,
diff --git a/vllm/logprobs.py b/vllm/logprobs.py
index a34398db2c960..6a820308f523f 100644
--- a/vllm/logprobs.py
+++ b/vllm/logprobs.py
@@ -5,8 +5,6 @@ from collections.abc import Iterable, Iterator, MutableSequence
from dataclasses import dataclass, field
from typing import overload
-import vllm.envs as envs
-
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
@@ -161,17 +159,17 @@ PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
-def create_prompt_logprobs() -> PromptLogprobs:
+def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
"""Creates a container to store prompt logprobs for a request"""
- logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+ logprobs = FlatLogprobs() if flat_logprobs else []
# NOTE: logprob of first prompt token is None.
logprobs.append(None)
return logprobs
-def create_sample_logprobs() -> SampleLogprobs:
+def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
"""Creates a container to store decode logprobs for a request"""
- return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+ return FlatLogprobs() if flat_logprobs else []
def append_logprobs_for_next_position(
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 893972144e99a..e2dd47dbb4e64 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -154,7 +154,7 @@ def _fused_moe_lora_kernel(
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
- # GDC wait waits for ALL programs in the the prior kernel to complete
+ # GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 746a543ab827d..7920d117de5e0 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import contextlib
import os
-from collections import namedtuple
from collections.abc import Callable
from functools import cache
from typing import Any
@@ -725,10 +723,6 @@ _original_cublas_workspace_cfg = None
_original_cublaslt_workspace_size = None
-def is_batch_invariant_mode_enabled():
- return _batch_invariant_MODE
-
-
def enable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
@@ -791,73 +785,6 @@ def enable_batch_invariant_mode():
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
-def disable_batch_invariant_mode():
- global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
- global _original_fp16_reduction_precision, _original_bf16_reduction_precision
- global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
- if not _batch_invariant_MODE:
- return
-
- if _batch_invariant_LIB is not None:
- _batch_invariant_LIB._destroy()
- if _original_torch_bmm is not None:
- torch.bmm = _original_torch_bmm
- _original_torch_bmm = None
-
- if _original_bf16_reduction_precision is not None:
- torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
- _original_bf16_reduction_precision
- )
- _original_bf16_reduction_precision = None
- if _original_fp16_reduction_precision is not None:
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
- _original_fp16_reduction_precision
- )
- _original_fp16_reduction_precision = None
-
- torch.backends.cuda.preferred_blas_library(backend="default")
-
- if not is_torch_equal_or_newer("2.10.0.dev"):
- # Set cublas env vars to previous results. If previous results are None,
- # that means the env vars were not set, so we should remove them.
- if _original_cublas_workspace_cfg:
- os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
- elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
- del os.environ["CUBLAS_WORKSPACE_CONFIG"]
-
- if _original_cublaslt_workspace_size:
- os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
- elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
- del os.environ["CUBLASLT_WORKSPACE_SIZE"]
-
- _original_cublas_workspace_cfg = None
- _original_cublaslt_workspace_size = None
-
- _batch_invariant_MODE = False
- _batch_invariant_LIB = None
-
-
-@contextlib.contextmanager
-def set_batch_invariant_mode(enabled: bool = True):
- global _batch_invariant_MODE, _batch_invariant_LIB
- old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
- if enabled:
- enable_batch_invariant_mode()
- else:
- disable_batch_invariant_mode()
- yield
- if _batch_invariant_LIB is not None:
- _batch_invariant_LIB._destroy()
- _batch_invariant_MODE, _batch_invariant_LIB = old_data
-
-
-AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
-
-
-def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
- return AttentionBlockSize(block_m=16, block_n=16)
-
-
@cache
def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT"
diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py
index e6f2d2990c241..8d51e5bd9920a 100644
--- a/vllm/model_executor/layers/conv.py
+++ b/vllm/model_executor/layers/conv.py
@@ -3,6 +3,7 @@
"""Conv Layer Class."""
import math
+from typing import Literal
import torch
import torch.nn as nn
@@ -23,11 +24,11 @@ class ConvLayerBase(CustomOp):
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
- padding: int | tuple[int, ...] = 0,
+ padding: int | tuple[int, ...] | Literal["same", "valid"] = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
bias: bool = True,
- padding_mode: str = "zeros",
+ padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
*,
params_dtype: torch.dtype | None = None,
) -> None:
@@ -36,6 +37,22 @@ class ConvLayerBase(CustomOp):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
+ valid_padding_strings = {"same", "valid"}
+ if isinstance(padding, str) and padding not in valid_padding_strings:
+ raise ValueError(
+ f"Invalid padding string '{padding}'. "
+ f"Expected one of {valid_padding_strings}."
+ )
+
+ if padding == "same":
+ padding = (
+ kernel_size // 2
+ if isinstance(kernel_size, int)
+ else tuple(k // 2 for k in kernel_size)
+ )
+ elif padding == "valid":
+ padding = 0
+
kernel_size = (
(kernel_size,) * self.num_dim
if isinstance(kernel_size, int)
@@ -45,6 +62,9 @@ class ConvLayerBase(CustomOp):
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
+ if padding == "same" and any(s != 1 for s in stride):
+ raise ValueError("padding='same' is not supported for strided convolutions")
+
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 79c92eb48612d..53362277dae8a 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -5,6 +5,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -19,7 +20,7 @@ from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
)
-from vllm.utils.math_utils import cdiv
+from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__)
@@ -313,6 +314,33 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output)
+ def estimate_expected_m(
+ self, global_num_experts: int, max_tokens_per_expert: int, topk: int
+ ) -> int:
+ dp_meta = (
+ get_forward_context().dp_metadata
+ if is_forward_context_available()
+ else None
+ )
+ if dp_meta is None:
+ logger.warning_once(
+ "DPMetadata unavailable. Defaulting expected_m to "
+ f"{max_tokens_per_expert}.",
+ scope="local",
+ )
+ return max_tokens_per_expert
+
+ total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
+ total_num_tokens_replicated = total_num_tokens * topk
+
+ # Assume even load balancing
+ assert global_num_experts != 0
+ estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
+ # clamp estimate
+ estimate = max(estimate, 16)
+ estimate = min(max_tokens_per_expert, estimate)
+ return estimate
+
def apply(
self,
output: torch.Tensor,
@@ -348,10 +376,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
- # (from deepgemm docs) : A value hint (which is a value on CPU)
- # for the M expectation of each batch, correctly setting this value
- # may lead to better performance.
- expected_m = max_num_tokens
+ expected_m = self.estimate_expected_m(
+ global_num_experts=global_num_experts,
+ max_tokens_per_expert=max_num_tokens,
+ topk=topk_ids.size(-1),
+ )
+
fp8_m_grouped_gemm_nt_masked(
(a1q, a1q_scale),
(w1, self.w1_scale),
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
new file mode 100644
index 0000000000000..555d173644522
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
@@ -0,0 +1,164 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "512": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 4,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 2,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 4,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 2,
+ "waves_per_eu": 0
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 23ace3408562a..572307052b489 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -6,7 +6,6 @@ import torch
from torch.nn import functional as F
from vllm import _custom_ops as ops
-from vllm import envs
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
@@ -130,54 +129,6 @@ def select_experts(
)
-class IPEXFusedMOE:
- def __init__(self, layer: torch.nn.Module) -> None:
- import intel_extension_for_pytorch as ipex
-
- layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
- layer.w13_weight,
- layer.w2_weight,
- use_prepack=envs.VLLM_CPU_MOE_PREPACK,
- )
-
- def __call__(
- self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- use_grouped_topk: bool,
- top_k: int,
- router_logits: torch.Tensor,
- renormalize: bool,
- topk_group: int | None = None,
- num_expert_group: int | None = None,
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- custom_routing_function: Callable | None = None,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0,
- e_score_correction_bias: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
- activation: str = "silu",
- ) -> torch.Tensor:
- assert activation == "silu", f"{activation} is not supported."
- assert not apply_router_weight_on_input
- assert routed_scaling_factor == 1.0, (
- f"routed_scaling_factor {routed_scaling_factor} is not supported."
- )
- return layer.ipex_fusion(
- x,
- use_grouped_topk,
- top_k,
- router_logits,
- renormalize,
- topk_group,
- num_expert_group,
- custom_routing_function,
- scoring_func,
- e_score_correction_bias,
- )
-
-
class SGLFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index 943695f921ad3..f864634c66176 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
- assert activation == "silu", (
- "Only activation silu is supported in FlashInferExperts"
+ from flashinfer.fused_moe.core import ActivationType
+
+ activation_str_to_value_map = {
+ "silu": ActivationType.Swiglu, # This is the default
+ "relu2_no_mul": ActivationType.Relu2,
+ }
+ assert activation in activation_str_to_value_map, (
+ f"{activation=} missing from {activation_str_to_value_map.keys()=}"
)
# Select quantization metadata based on FP8 format/path
@@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_size=self.ep_size,
ep_rank=self.ep_rank,
output=output,
+ activation_type=activation_str_to_value_map[activation],
# Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
)
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index ce56887f1c26d..2e0376553b913 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -260,7 +260,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w2_weight.copy_(packed_w2_weight)
layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
else:
- layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
+ layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index fb45afa33dad6..57313990b8206 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -426,6 +426,10 @@ class MambaMixer2(MambaBase, CustomOp):
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
# and `set_weight_attrs` doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+ conv_weights = self.conv1d.weight.view(
+ self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+ )
+ self.register_buffer("conv_weights", conv_weights, persistent=False)
# - these are TPed by heads to reduce the size of the
# temporal shape
@@ -459,6 +463,17 @@ class MambaMixer2(MambaBase, CustomOp):
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)
+ # - get hidden_states, B and C after depthwise convolution.
+ self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+ hidden_states_B_C,
+ [
+ self.intermediate_size // self.tp_size,
+ self.groups_ssm_state_size // self.tp_size,
+ self.groups_ssm_state_size // self.tp_size,
+ ],
+ dim=-1,
+ )
+
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
@@ -470,10 +485,24 @@ class MambaMixer2(MambaBase, CustomOp):
self.cache_config = cache_config
self.prefix = prefix
+ # Pre-compute sizes for forward pass
+ self.tped_intermediate_size = self.intermediate_size // self.tp_size
+ self.tped_conv_size = self.conv_dim // self.tp_size
+ self.tped_dt_size = self.num_heads // self.tp_size
+
+ self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+ hidden_states_B_C,
+ [
+ self.tped_intermediate_size,
+ self.groups_ssm_state_size // self.tp_size,
+ self.groups_ssm_state_size // self.tp_size,
+ ],
+ dim=-1,
+ )
+
def forward_native(
self,
hidden_states: torch.Tensor,
- output: torch.Tensor,
mup_vector: torch.Tensor | None = None,
):
pass
@@ -481,22 +510,55 @@ class MambaMixer2(MambaBase, CustomOp):
def forward(
self,
hidden_states: torch.Tensor,
- output: torch.Tensor,
mup_vector: torch.Tensor | None = None,
):
- torch.ops.vllm.mamba_mixer2(
- hidden_states,
- output,
- self.prefix,
- mup_vector,
+ # 1. Gated MLP's linear projection
+ projected_states, _ = self.in_proj(hidden_states)
+ if mup_vector is not None:
+ projected_states = projected_states * mup_vector
+
+ # 2. Prepare inputs for conv + SSM
+ ssm_output = torch.empty(
+ [
+ hidden_states.shape[0],
+ (self.num_heads // self.tp_size) * self.head_dim,
+ ],
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
)
- def forward_cuda(
+ # 3. conv + SSM
+ # (split `projected_states` into hidden_states_B_C, dt in the custom op to
+ # ensure it is not treated as an intermediate tensor by torch compile)
+ torch.ops.vllm.mamba_mixer2(
+ projected_states,
+ ssm_output,
+ self.prefix,
+ )
+
+ # 4. gated MLP
+ # GatedRMSNorm internally applying SiLU to the gate
+ # SiLU is applied internally before normalization, unlike standard
+ # norm usage
+ gate = projected_states[..., : self.tped_intermediate_size]
+ hidden_states = self.norm(ssm_output, gate)
+
+ # 5. Final linear projection
+ output, _ = self.out_proj(hidden_states)
+
+ return output
+
+ def conv_ssm_forward(
self,
- hidden_states: torch.Tensor,
+ projected_states: torch.Tensor,
output: torch.Tensor,
- mup_vector: torch.Tensor | None = None,
):
+ hidden_states_B_C, dt = torch.split(
+ projected_states[..., self.tped_intermediate_size :],
+ [self.tped_conv_size, self.tped_dt_size],
+ dim=-1,
+ )
+
forward_context = get_forward_context()
# attn_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
@@ -524,46 +586,13 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
- # 1. Gated MLP's linear projection
- projected_states, _ = self.in_proj(hidden_states)
-
- if mup_vector is not None:
- projected_states = projected_states * mup_vector
-
- gate, hidden_states_B_C, dt = torch.split(
- projected_states,
- [
- self.intermediate_size // self.tp_size,
- self.conv_dim // self.tp_size,
- self.num_heads // self.tp_size,
- ],
- dim=-1,
- )
-
- conv_weights = self.conv1d.weight.view(
- self.conv1d.weight.size(0), self.conv1d.weight.size(2)
- )
-
- # - get hidden_states, B and C after depthwise convolution.
- split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
- hidden_states_B_C,
- [
- self.intermediate_size // self.tp_size,
- self.groups_ssm_state_size // self.tp_size,
- self.groups_ssm_state_size // self.tp_size,
- ],
- dim=-1,
- )
-
if attn_metadata is None:
# profile run
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
- hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
- hidden_states = self.norm(hidden_states, gate)
- out, _ = self.out_proj(hidden_states)
- return out
+ hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
+ return hidden_states
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count
@@ -622,18 +651,8 @@ class MambaMixer2(MambaBase, CustomOp):
block_idx_first_scheduled_token_p = None
num_computed_tokens_p = None
- # Preallocate output tensor to avoid memcpy cost for merging prefill
- # and decode outputs
- preallocated_ssm_out = torch.empty(
- [
- num_prefill_tokens + num_decodes,
- (self.num_heads // self.tp_size) * self.head_dim,
- ],
- dtype=hidden_states.dtype,
- device=hidden_states.device,
- )
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
- preallocated_ssm_out,
+ output[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
@@ -658,7 +677,7 @@ class MambaMixer2(MambaBase, CustomOp):
) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
x,
- conv_weights,
+ self.conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
@@ -673,7 +692,9 @@ class MambaMixer2(MambaBase, CustomOp):
query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens]
- hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
+ hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(
+ hidden_states_B_C_p
+ )
# 3. State Space Model sequence transformation
initial_states = None
@@ -815,7 +836,7 @@ class MambaMixer2(MambaBase, CustomOp):
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
conv_state,
- conv_weights,
+ self.conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d,
@@ -823,7 +844,9 @@ class MambaMixer2(MambaBase, CustomOp):
initial_state_idx=block_idx_last_computed_token_d,
)
- hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
+ hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
+ hidden_states_B_C_d
+ )
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
@@ -861,15 +884,6 @@ class MambaMixer2(MambaBase, CustomOp):
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)
- # 4. gated MLP
- # GatedRMSNorm internally applying SiLU to the gate
- # SiLU is applied internally before normalization, unlike standard
- # norm usage
- hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
-
- # 5. Final linear projection
- output[:num_actual_tokens], _ = self.out_proj(hidden_states)
-
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
@@ -901,21 +915,19 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_mixer2(
- hidden_states: torch.Tensor,
+ projected_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
- mup_vector: torch.Tensor | None = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
- self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)
+ self.conv_ssm_forward(projected_states=projected_states, output=output)
def mamba_mixer2_fake(
- hidden_states: torch.Tensor,
+ projected_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
- mup_vector: torch.Tensor | None = None,
) -> None:
return
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index bb42b10f87186..18aaae394f935 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -38,6 +38,8 @@ QuantizationMethods = Literal[
"inc",
"mxfp4",
"petit_nvfp4",
+ "cpu_gptq",
+ "cpu_awq",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -107,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
+ from .cpu_wna16 import CPUAWQConfig, CPUGPTQConfig
from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
@@ -159,6 +162,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
+ "cpu_gptq": CPUGPTQConfig,
+ "cpu_awq": CPUAWQConfig,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
diff --git a/vllm/model_executor/layers/quantization/cpu_wna16.py b/vllm/model_executor/layers/quantization/cpu_wna16.py
new file mode 100644
index 0000000000000..bf643f55f1b9a
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/cpu_wna16.py
@@ -0,0 +1,625 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any, Optional
+
+import torch
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
+
+from vllm._custom_ops import (
+ cpu_gemm_wna16,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (
+ LinearBase,
+ LinearMethodBase,
+ UnquantizedLinearMethod,
+)
+from vllm.model_executor.layers.quantization import QuantizationMethods
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig,
+ QuantizeMethodBase,
+)
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+ get_linear_quant_method,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+ marlin_repeat_scales_on_all_ranks,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ is_layer_skipped,
+ pack_cols,
+ unpack_cols,
+)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.utils import WeightsMapper
+from vllm.model_executor.parameter import (
+ ChannelQuantScaleParameter,
+ GroupQuantScaleParameter,
+ PackedColumnParameter,
+ PackedvLLMParameter,
+ RowvLLMParameter,
+)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_safetensors_params_metadata
+from vllm.utils.collection_utils import is_list_of
+
+logger = init_logger(__name__)
+
+
+class CPUGPTQConfig(QuantizationConfig):
+ """Config class for CPU GPTQ quant"""
+
+ def __init__(
+ self,
+ weight_bits: int,
+ group_size: int,
+ desc_act: bool,
+ is_sym: bool,
+ lm_head_quantized: bool,
+ dynamic: dict[str, dict[str, int | bool]],
+ full_config: dict[str, Any],
+ modules_in_block_to_quantize: list[str] | None = None,
+ ) -> None:
+ super().__init__()
+ if desc_act and group_size == -1:
+ # In this case, act_order == True is the same as act_order == False
+ # (since we have only one group per output channel)
+ desc_act = False
+
+ # GPTQModel use `dynamic` config property to allow per module
+ # quantization config so each module can be individually optimized.
+ # Format is dict[str, dict] where key is a regex string that can
+ # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+ # matching of a module.
+ # Default to positive match, override base quant config mode, if no
+ # prefix is used. Value is in dict format of field key and override
+ # value.
+ # Negative matching will skip quantization init for this module
+ # entirely:
+ # non-quantized inference. More details and quantization examples can be
+ # found at: https://github.com/ModelCloud/GPTQModel
+ # Example:
+ # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+ # # last 1/4 of the layers 16-21 has 8bit and group_size 64
+ # dynamic = {
+ # #`.*\.` matches the layers_node prefix
+ # # positive match layer 10-15
+ # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+ # # positive match layer 16-21
+ # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+ # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+ # }
+ assert weight_bits == 4
+ self.dynamic = dynamic
+ self.weight_bits = weight_bits
+ self.is_sym = is_sym
+ self.pack_factor = 32 // weight_bits # packed into int32
+ self.group_size = group_size
+ self.desc_act = desc_act
+ self.lm_head_quantized = lm_head_quantized
+ self.full_config = full_config
+ self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
+
+ def __repr__(self) -> str:
+ return (
+ f"CPUWNA16Config("
+ f"group_size={self.group_size}, "
+ f"desc_act={self.desc_act}, "
+ f"lm_head_quantized={self.lm_head_quantized}, "
+ f"dynamic={self.dynamic}, "
+ f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
+ )
+
+ @classmethod
+ def get_name(cls) -> QuantizationMethods:
+ return "cpu_gptq"
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+ return [torch.half, torch.bfloat16]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ return -1
+
+ @classmethod
+ def get_config_filenames(cls) -> list[str]:
+ return ["quantize_config.json"]
+
+ @classmethod
+ def from_config(cls, config: dict[str, Any]) -> "CPUGPTQConfig":
+ weight_bits = cls.get_from_keys(config, ["bits"])
+ desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
+ dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+ group_size = cls.get_from_keys(config, ["group_size"])
+ is_sym = cls.get_from_keys(config, ["sym"])
+ lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+ modules_in_block_to_quantize = cls.get_from_keys_or(
+ config, ["modules_in_block_to_quantize"], default=None
+ )
+ return cls(
+ weight_bits,
+ group_size,
+ desc_act,
+ is_sym,
+ lm_head_quantized,
+ dynamic,
+ config,
+ modules_in_block_to_quantize,
+ )
+
+ @classmethod
+ def override_quantization_method(
+ cls, hf_quant_cfg, user_quant
+ ) -> QuantizationMethods | None:
+ quant_method = hf_quant_cfg.get("quant_method", "").lower()
+ if current_platform.is_cpu() and (quant_method == "gptq"):
+ return cls.get_name()
+ return None
+
+ def get_quant_method(
+ self, layer: torch.nn.Module, prefix: str
+ ) -> Optional["QuantizeMethodBase"]:
+ return get_linear_quant_method(self, layer, prefix, CPUGPTQLinearMethod) # type: ignore
+
+ def apply_vllm_mapper(self, hf_to_vllm_mapper):
+ if self.modules_in_block_to_quantize is not None:
+ self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
+ self.modules_in_block_to_quantize
+ )
+
+ def maybe_update_config(self, model_name: str, revision: str | None = None):
+ if self.modules_in_block_to_quantize:
+ if is_list_of(self.modules_in_block_to_quantize, list):
+ # original modules_in_block_to_quantize: list[list[str]]
+ # flatten original modules_in_block_to_quantize
+ self.modules_in_block_to_quantize = [
+ item
+ for sublist in self.modules_in_block_to_quantize
+ for item in sublist
+ ]
+ return
+
+ unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ metadata = get_safetensors_params_metadata(model_name, revision=revision)
+ quant_layers: set[str] = {
+ param_name.rsplit(".", 1)[0]
+ for param_name, info in metadata.items()
+ if (dtype := info.get("dtype", None))
+ and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+ }
+ self.modules_in_block_to_quantize = list(quant_layers)
+
+
+class CPUGPTQLinearMethod(LinearMethodBase):
+ """Linear method for GPTQ on CPU.
+
+ Args:
+        quant_config: The CPU GPTQ quantization config.
+ """
+
+ def __init__(self, quant_config: CPUGPTQConfig) -> None:
+ self.quant_config = quant_config
+ assert self.quant_config.is_sym, "GPTQ asym quant is not supported on CPU"
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: list[int],
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ) -> None:
+ output_size_per_partition = sum(output_partition_sizes)
+ assert output_size_per_partition * self.quant_config.weight_bits % 32 == 0
+ assert output_size_per_partition % 32 == 0
+ assert input_size_per_partition % 32 == 0
+
+ is_row_parallel = input_size != input_size_per_partition
+ weight_loader = extra_weight_attrs.get("weight_loader")
+
+ # Normalize group_size
+ if self.quant_config.group_size != -1:
+ group_size = self.quant_config.group_size
+ else:
+ group_size = input_size
+
+ # Determine sharding
+ if marlin_repeat_scales_on_all_ranks(
+ self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
+ ):
+ # By setting scale_dim == None, weight_loader will
+ # repeat the scales on each rank in TP>1 case.
+ scales_and_zp_input_dim = None
+ scales_and_zp_size = input_size // group_size
+ else:
+ # By setting scale_dim == 0, weight_loader will
+ # shard the scales in TP>1 case.
+ scales_and_zp_input_dim = 0
+ scales_and_zp_size = input_size_per_partition // group_size
+
+ # Quantized weights
+ qweight = PackedvLLMParameter(
+ data=torch.empty(
+ input_size_per_partition // self.quant_config.pack_factor,
+ output_size_per_partition,
+ dtype=torch.int32,
+ ),
+ input_dim=0,
+ output_dim=1,
+ packed_dim=0,
+ packed_factor=self.quant_config.pack_factor,
+ weight_loader=weight_loader,
+ )
+
+ # Activation order
+ g_idx = RowvLLMParameter(
+ data=torch.empty(
+ input_size_per_partition,
+ dtype=torch.int32,
+ ),
+ input_dim=0,
+ weight_loader=weight_loader,
+ )
+ set_weight_attrs(
+ g_idx,
+ {"ignore_warning": True},
+ )
+
+ qzeros_args = {
+ "data": torch.empty(
+ scales_and_zp_size,
+ output_size_per_partition // self.quant_config.pack_factor,
+ dtype=torch.int32,
+ ),
+ "weight_loader": weight_loader,
+ }
+ weight_scale_args = {
+ "data": torch.empty(
+ scales_and_zp_size,
+ output_size_per_partition,
+ dtype=params_dtype,
+ ),
+ "weight_loader": weight_loader,
+ }
+
+ if scales_and_zp_input_dim is None:
+ scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
+ qzeros = PackedColumnParameter(
+ output_dim=1,
+ packed_dim=1,
+ packed_factor=self.quant_config.pack_factor,
+ **qzeros_args,
+ )
+
+ else:
+ scales = GroupQuantScaleParameter(
+ output_dim=1, input_dim=0, **weight_scale_args
+ )
+ qzeros = PackedvLLMParameter(
+ input_dim=0,
+ output_dim=1,
+ packed_dim=1,
+ packed_factor=self.quant_config.pack_factor,
+ **qzeros_args,
+ )
+
+ layer.register_parameter("qweight", qweight)
+ layer.register_parameter("g_idx", g_idx)
+ layer.register_parameter("scales", scales)
+ layer.register_parameter("qzeros", qzeros)
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False)
+ packed_weight = layer.qweight.data
+ bits = self.quant_config.weight_bits
+ pack_factor = int(self.quant_config.pack_factor)
+ p_w_k, p_w_n = packed_weight.size()
+ input_size = p_w_k * pack_factor
+ output_size = p_w_n
+ isa_hint = _get_isa_hint(layer.scales.dtype)
+ layer.isa_hint = isa_hint
+
+ layer.qzeros = None
+ if not self.quant_config.desc_act:
+ layer.g_idx = None
+
+ # convert input dim packed to output dim packed
+ weight = unpack_cols(packed_weight, bits, p_w_k, p_w_n * pack_factor).view(
+ p_w_k, p_w_n, pack_factor
+ )
+ weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
+ weight = pack_cols(weight, bits, input_size, output_size)
+        # group every 16 output channels into a block and transpose to
+        # make each block contiguous
+ weight = (
+ weight.view(input_size, -1, 16 // pack_factor)
+ .permute(1, 0, 2)
+ .reshape(-1, input_size * 16 // pack_factor)
+ .contiguous()
+ )
+ layer.qweight.data = weight
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ x = cpu_gemm_wna16(
+ input=x,
+ q_weight=layer.qweight,
+ scales=layer.scales,
+ zeros=layer.qzeros,
+ g_idx=layer.g_idx,
+ bias=bias,
+ pack_factor=8,
+ isa_hint=layer.isa_hint,
+ )
+ return x
+
+
+class CPUAWQConfig(QuantizationConfig):
+ """Config class for CPU AWQ"""
+
+ def __init__(
+ self,
+ weight_bits: int,
+ group_size: int,
+ zero_point: bool,
+ lm_head_quantized: bool,
+ modules_to_not_convert: list[str] | None,
+ full_config: dict[str, Any],
+ ) -> None:
+ super().__init__()
+ assert weight_bits == 4
+ self.pack_factor = 32 // weight_bits # packed into int32
+ self.group_size = group_size
+ self.zero_point = zero_point
+ self.lm_head_quantized = lm_head_quantized
+ self.weight_bits = weight_bits
+ self.modules_to_not_convert = modules_to_not_convert or []
+ self.full_config = full_config
+
+ def __repr__(self) -> str:
+ return (
+ f"AWQMarlinConfig("
+ f"group_size={self.group_size}, "
+ f"zero_point={self.zero_point}, "
+ f"lm_head_quantized={self.lm_head_quantized}, "
+ f"modules_to_not_convert={self.modules_to_not_convert})"
+ )
+
+ @classmethod
+ def get_name(cls) -> "QuantizationMethods":
+ return "cpu_awq"
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+ return [torch.half, torch.bfloat16]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ return -1
+
+ @classmethod
+ def get_config_filenames(cls) -> list[str]:
+ return ["quantize_config.json"]
+
+ @classmethod
+ def from_config(cls, config: dict[str, Any]) -> "CPUAWQConfig":
+ weight_bits = cls.get_from_keys(config, ["bits"])
+ group_size = cls.get_from_keys(config, ["group_size"])
+ zero_point = cls.get_from_keys(config, ["zero_point"])
+ lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+ modules_to_not_convert = cls.get_from_keys_or(
+ config, ["modules_to_not_convert"], None
+ )
+ return cls(
+ weight_bits,
+ group_size,
+ zero_point,
+ lm_head_quantized,
+ modules_to_not_convert,
+ config,
+ )
+
+ @classmethod
+ def override_quantization_method(
+ cls, hf_quant_cfg, user_quant
+ ) -> Optional["QuantizationMethods"]:
+ quant_method = hf_quant_cfg.get("quant_method", "").lower()
+ if current_platform.is_cpu() and (quant_method == "awq"):
+ return cls.get_name()
+ return None
+
+ def get_quant_method(
+ self, layer: torch.nn.Module, prefix: str
+ ) -> Optional["QuantizeMethodBase"]:
+ if isinstance(layer, LinearBase) or (
+ isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+ ):
+ if is_layer_skipped(
+ prefix,
+ self.modules_to_not_convert,
+ self.packed_modules_mapping,
+ skip_with_substr=True,
+ ):
+ return UnquantizedLinearMethod()
+ return CPUAWQLinearMethod(self)
+ return None
+
+ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+ if self.modules_to_not_convert:
+ self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
+ self.modules_to_not_convert
+ )
+
+ def maybe_update_config(self, model_name: str, revision: str | None = None):
+ if self.modules_to_not_convert:
+ return
+
+ unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ metadata = get_safetensors_params_metadata(model_name, revision=revision)
+ layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
+ quant_layers: set[str] = {
+ param_name.rsplit(".", 1)[0]
+ for param_name, info in metadata.items()
+ if (dtype := info.get("dtype", None))
+ and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+ }
+ self.modules_to_not_convert = list(layers - quant_layers)
+
+
+class CPUAWQLinearMethod(LinearMethodBase):
+ """Linear method for CPU AWQ.
+
+ Args:
+ quant_config: The CPU AWQ quantization config.
+ """
+
+ def __init__(self, quant_config: CPUAWQConfig) -> None:
+ self.quant_config = quant_config
+ assert self.quant_config.zero_point
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: list[int],
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ) -> None:
+ del output_size
+ output_size_per_partition = sum(output_partition_sizes)
+ weight_loader = extra_weight_attrs.get("weight_loader")
+
+ # Normalize group_size
+ if self.quant_config.group_size != -1:
+ group_size = self.quant_config.group_size
+ else:
+ group_size = input_size
+
+ qweight = PackedvLLMParameter(
+ data=torch.empty(
+ input_size_per_partition,
+ output_size_per_partition // self.quant_config.pack_factor,
+ dtype=torch.int32,
+ ),
+ input_dim=0,
+ output_dim=1,
+ packed_dim=1,
+ packed_factor=self.quant_config.pack_factor,
+ weight_loader=weight_loader,
+ )
+
+ num_groups = input_size_per_partition // group_size
+
+ qzeros = PackedvLLMParameter(
+ data=torch.empty(
+ num_groups,
+ output_size_per_partition // self.quant_config.pack_factor,
+ dtype=torch.int32,
+ ),
+ input_dim=0,
+ output_dim=1,
+ packed_dim=1,
+ packed_factor=self.quant_config.pack_factor,
+ weight_loader=weight_loader,
+ )
+
+ scales = GroupQuantScaleParameter(
+ data=torch.empty(
+ num_groups,
+ output_size_per_partition,
+ dtype=params_dtype,
+ ),
+ input_dim=0,
+ output_dim=1,
+ weight_loader=weight_loader,
+ )
+
+ layer.register_parameter("qweight", qweight)
+ layer.register_parameter("qzeros", qzeros)
+ layer.register_parameter("scales", scales)
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False)
+ packed_weight = layer.qweight.data
+ packed_zeros = layer.qzeros.data
+ group_num = packed_zeros.size(0)
+ bits = self.quant_config.weight_bits
+ pack_factor = int(self.quant_config.pack_factor)
+ input_size, packed_output_size = packed_weight.size()
+ output_size = packed_output_size * pack_factor
+ isa_hint = _get_isa_hint(layer.scales.dtype)
+ layer.isa_hint = isa_hint
+
+ interleave_map = (0, 4, 1, 5, 2, 6, 3, 7)
+ weight = unpack_cols(
+ packed_weight,
+ bits,
+ input_size,
+ output_size,
+ )
+ zeros = unpack_cols(
+ packed_zeros,
+ bits,
+ group_num,
+ output_size,
+ )
+ weight = (
+ weight.view(input_size, -1, pack_factor)[:, :, interleave_map]
+ .reshape(input_size, output_size)
+ .contiguous()
+ )
+ zeros = (
+ zeros.view(group_num, -1, pack_factor)[:, :, interleave_map]
+ .reshape(group_num, output_size)
+ .contiguous()
+ )
+
+ zeros = pack_cols(zeros, bits, group_num, output_size).contiguous()
+        # group every 16 output channels into a block and transpose to
+        # make each block contiguous
+ weight = pack_cols(weight, bits, input_size, output_size)
+ weight = (
+ weight.view(input_size, -1, 16 // pack_factor)
+ .permute(1, 0, 2)
+ .reshape(-1, input_size * 16 // pack_factor)
+ .contiguous()
+ )
+ layer.qweight.data = weight
+ layer.qzeros.data = zeros
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ x = cpu_gemm_wna16(
+ input=x,
+ q_weight=layer.qweight,
+ scales=layer.scales,
+ zeros=layer.qzeros,
+ g_idx=None,
+ bias=bias,
+ pack_factor=8,
+ isa_hint=layer.isa_hint,
+ )
+ return x
+
+
+def _get_isa_hint(dtype: torch.dtype) -> str:
+ supports_amx = torch._C._cpu._is_amx_tile_supported()
+ if supports_amx and dtype in (torch.bfloat16,):
+ return "amx"
+ else:
+ return "vec"
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index caabcd0ca0ee5..42d7a67371ae8 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
+from types import MappingProxyType
from typing import Any, Optional
import gguf
@@ -26,7 +27,11 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ UnquantizedEmbeddingMethod,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
@@ -65,18 +70,70 @@ class GGUFConfig(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
- if is_layer_skipped_gguf(prefix, self.unquantized_modules):
+ if is_layer_skipped_gguf(
+ prefix, self.unquantized_modules, self.packed_modules_mapping
+ ):
return UnquantizedLinearMethod()
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
+ if is_layer_skipped_gguf(
+ prefix, self.unquantized_modules, self.packed_modules_mapping
+ ):
+ return UnquantizedEmbeddingMethod()
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self, layer.moe_config)
return None
+ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+ """
+ Interface for models to update module names referenced in
+ quantization configs in order to reflect the vllm model structure
-def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
- return any(module_name in prefix for module_name in unquantized_modules)
+ :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+ structure of the qconfig) to vllm model structure
+ """
+ if self.unquantized_modules is not None:
+ self.unquantized_modules = hf_to_vllm_mapper.apply_list(
+ self.unquantized_modules
+ )
+
+
+def is_layer_skipped_gguf(
+ prefix: str,
+ unquantized_modules: list[str],
+ fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+):
+ # Fused layers like gate_up_proj or qkv_proj will not be fused
+ # in the safetensors checkpoint. So, we convert the name
+ # from the fused version to unfused + check to make sure that
+ # each shard of the fused layer has the same scheme.
+ proj_name = prefix.split(".")[-1]
+ if proj_name in fused_mapping:
+ shard_prefixes = [
+ prefix.replace(proj_name, shard_proj_name)
+ for shard_proj_name in fused_mapping[proj_name]
+ ]
+
+ is_skipped = None
+ for shard_prefix in shard_prefixes:
+ is_shard_skipped = any(
+ shard_prefix in module_name for module_name in unquantized_modules
+ )
+
+ if is_skipped is None:
+ is_skipped = is_shard_skipped
+ elif is_shard_skipped != is_skipped:
+ raise ValueError(
+ f"Detected some but not all shards of {prefix} "
+ "are quantized. All shards of fused layers "
+ "to have the same precision."
+ )
+ else:
+ is_skipped = any(module_name in prefix for module_name in unquantized_modules)
+
+ assert is_skipped is not None
+ return is_skipped
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 5ca9167faec80..22c4bae041a56 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -134,7 +134,7 @@ class IPEXConfig(QuantizationConfig):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
- if not current_platform.is_cpu() and not current_platform.is_xpu():
+ if not current_platform.is_xpu():
return None
quant_method = hf_quant_cfg.get("quant_method", "").lower()
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index e14753c60c485..476521813f464 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
+ RoutingMethodType,
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
@@ -354,12 +355,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
- if (
- envs.VLLM_USE_FLASHINFER_MOE_FP8
- and has_flashinfer_moe()
- and self.moe.is_act_and_mul
- ):
+ if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+ if (
+ self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ and not self.moe.is_act_and_mul
+ ):
+ logger.info_once(
+ "Non-gated MoE is not supported for min-latency mode,"
+ "falling back to high-throughput mode"
+ )
+ self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
)
@@ -557,10 +564,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
if self.flashinfer_moe_backend is not None:
- layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
- register_moe_scaling_factors(layer)
+ if self.moe.is_act_and_mul:
+ layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
+ register_moe_scaling_factors(layer)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -570,13 +578,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
- g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
+ g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
w2_scale=layer.w2_weight_scale,
- g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
+ g2_alphas=layer.output2_scales_scalar.squeeze(),
a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
- a2_gscale=1.0 / layer.w2_input_scale,
+ a2_gscale=layer.w2_input_scale_inv,
per_act_token_quant=False,
)
@@ -642,9 +650,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
- assert not renormalize
- assert activation == "silu", (
- f"Expected 'silu' activation but got {activation}"
+ assert activation in ("silu", "relu2_no_mul"), (
+ "Expected activation to be in ('silu', 'relu2_no_mul'),"
+ f"but got {activation}"
)
return flashinfer_cutlass_moe_fp8(
x,
@@ -1650,16 +1658,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
use_llama4_routing = (
custom_routing_function is Llama4MoE.custom_routing_function
)
- routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+ routing_method_type = layer.routing_method_type
if use_llama4_routing:
- routing_method_type = flashinfer.RoutingMethodType.Llama4
+ routing_method_type = RoutingMethodType.Llama4
+ router_logits = (
+ router_logits.to(torch.float32)
+ if routing_method_type == RoutingMethodType.DeepSeekV3
+ else router_logits
+ )
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
- routing_logits=router_logits
- if use_llama4_routing
- else router_logits.to(torch.float32),
+ routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
@@ -1683,8 +1694,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
- n_group=num_expert_group if num_expert_group is not None else 0,
- topk_group=topk_group if topk_group is not None else 0,
+ n_group=num_expert_group,
+ topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 5552c1ae5edf8..b95d1a6b3a1f5 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -755,8 +755,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w13_weight = w13_weight
self.w2_weight = w2_weight
- layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
- layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
+ layer.w13_weight = Parameter(w13_weight.storage.data, requires_grad=False)
+ layer.w2_weight = Parameter(w2_weight.storage.data, requires_grad=False)
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index 1bb698faf46df..f59e5e2a0af7a 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare,
should_ignore_layer,
)
+from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
if TYPE_CHECKING:
@@ -57,7 +58,6 @@ class QuarkConfig(QuantizationConfig):
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
- self.ignore: list[str] = cast(list[str], self.quant_config.get("exclude", []))
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
@@ -72,14 +72,42 @@ class QuarkConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "quark"
+ def apply_vllm_mapper( # noqa: B027
+ self, hf_to_vllm_mapper: "WeightsMapper"
+ ):
+ """
+ Interface for models to update module names referenced in
+ quantization configs in order to reflect the vllm model structure
+
+ :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+ structure of the qconfig) to vllm model structure
+ """
+ quant_config_with_hf_to_vllm_mapper = {}
+
+ for k, v in self.quant_config.items():
+ if isinstance(v, list):
+ quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_list(v)
+ elif isinstance(v, dict):
+ quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_dict(v)
+ else:
+ if isinstance(v, str):
+ mapped_v_list = hf_to_vllm_mapper.apply_list([v])
+ if mapped_v_list:
+ quant_config_with_hf_to_vllm_mapper[k] = mapped_v_list[0]
+ else:
+ quant_config_with_hf_to_vllm_mapper[k] = v
+
+ self.quant_config = quant_config_with_hf_to_vllm_mapper
+
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
+ exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(
- prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+ prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
@@ -93,9 +121,6 @@ class QuarkConfig(QuantizationConfig):
return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
return None
- def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
- self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
-
@classmethod
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
export_config = config.get("export")
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index d9e9b42402712..f22e17945d1f6 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -291,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
# TODO(shuw@nvidia): Update when new backends are added.
- backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
+ backends_supporting_global_sf = (
+ FlashinferMoeBackend.CUTLASS,
+ FlashinferMoeBackend.TENSORRT_LLM,
+ )
return backend in backends_supporting_global_sf
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 34a31bcf6a747..cbc46810a26a6 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -8,6 +8,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import triton
+from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
@@ -15,6 +16,7 @@ logger = init_logger(__name__)
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
+ assert has_triton_kernels()
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.numerics import InFlexData
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index ce4f40680b0a3..4114b21168cc8 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -83,6 +83,11 @@ class RotaryEmbeddingBase(CustomOp):
):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+ def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
+ cos_sin = self.cos_sin_cache[:seqlen]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ return cos, sin
+
class RotaryEmbedding(RotaryEmbeddingBase):
def __init__(
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index c06ac550a94ae..b80026741781f 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
fastsafetensors_weights_iterator,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
+ get_quant_config,
maybe_download_from_modelscope,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
- if model_config.quantization == "torchao" and torchao_version_at_least(
- "0.14.0"
- ):
- self.load_config.safetensors_load_strategy = "torchao"
+ if model_config.quantization == "torchao":
+ quant_config = get_quant_config(model_config, self.load_config)
+ if (
+ hasattr(quant_config, "is_checkpoint_torchao_serialized")
+ and quant_config.is_checkpoint_torchao_serialized
+ and torchao_version_at_least("0.14.0")
+ ):
+ self.load_config.safetensors_load_strategy = "torchao"
+
weights_to_load = {name for name, _ in model.named_parameters()}
-
- # if we don't have `model.weight_metadata_and_attr_saved` defined and
- # set to True, it means that this is either offline quantization case
- # or the first run of online quantization
- # see online_quantization.py for detailed notes
- offline_quantization_or_first_run_of_online_quantization = not getattr(
- model, "weight_metadata_and_attr_saved", False
- )
-
- if model_config.quantization is None:
- # model is not quantized
- loaded_weights = model.load_weights(
- self.get_all_weights(model_config, model)
- )
- elif offline_quantization_or_first_run_of_online_quantization:
- # case 1: offline quantized checkpoint
- # case 2: Step I1 first run of weight loading with
- # online quantization
- # see online_quantization.py for detailed notes
- loaded_weights = model.load_weights(
- self.get_all_weights(model_config, model)
- )
- else:
- # to avoid circular dependency
- from vllm.model_executor.model_loader.online_quantization import (
- load_weights_and_online_quantize,
- )
-
- # subsequent runs of weight loading with online
- # quantization
- loaded_weights = load_weights_and_online_quantize(self, model, model_config)
+ loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 7db1fc167c4fa..2416836be03c4 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -7,10 +7,11 @@ import gguf
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
+from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
@@ -21,8 +22,11 @@ from vllm.model_executor.model_loader.weight_utils import (
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
)
+from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
from vllm.utils.torch_utils import set_default_torch_dtype
+logger = init_logger(__name__)
+
class GGUFModelLoader(BaseModelLoader):
"""
@@ -67,7 +71,15 @@ class GGUFModelLoader(BaseModelLoader):
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config = model_config.hf_config
+ # Get text config to handle both nested (multimodal) and flat
+ # (text-only) config structures. For multimodal models like
+ # Gemma3Config, this returns config.text_config. For text-only
+ # models, this returns config itself.
+ text_config = config.get_text_config()
model_type = config.model_type
+ is_multimodal = (
+ hasattr(config, "vision_config") and config.vision_config is not None
+ )
gguf_to_hf_name_map = {}
# hack: ggufs have a different name than transformers
if model_type == "cohere":
@@ -115,24 +127,167 @@ class GGUFModelLoader(BaseModelLoader):
break
if arch is None:
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
- num_layers = config.num_hidden_layers
- name_map = gguf.get_tensor_name_map(arch, num_layers)
+ text_num_layers = text_config.num_hidden_layers
+ text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)
+
+ if is_multimodal:
+ mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
+ vision_num_layers = config.vision_config.num_hidden_layers
+ vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
+ else:
+ vision_name_map = None
+
+ # Create dummy model to extract parameter names
+ # For multimodal: use AutoModelForImageTextToText to get
+ # language + vision + projector params
+ # For text-only: use AutoModelForCausalLM to get language model params
+ auto_cls = (
+ AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
+ )
with torch.device("meta"):
- dummy_model = AutoModelForCausalLM.from_config(
+ dummy_model = auto_cls.from_config(
config, trust_remote_code=model_config.trust_remote_code
)
- state_dict = dummy_model.state_dict()
+ state_dict = dummy_model.state_dict()
+ if hf_checkpoint_map := getattr(
+ dummy_model, "_checkpoint_conversion_mapping", None
+ ):
+
+ def revert_hf_rename(name: str) -> str:
+ for original_name, hf_name in hf_checkpoint_map.items():
+ if hf_name in name:
+ name = name.replace(hf_name, original_name).lstrip("^")
+ return name
+
+ state_dict = {
+ revert_hf_rename(name): tensor for name, tensor in state_dict.items()
+ }
+
+ def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
+ """
+ Map HuggingFace parameter name to GGUF tensor name.
+
+ This function handles the mismatch between HF parameter naming
+ conventions and gguf-py's expected format:
+ 1. Strips 'model.' prefix (common in multimodal models)
+ 2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
+ 3. Searches vision_name_map for multimodal parameters
+ 4. Falls back to text_name_map for language model parameters
+
+ Args:
+ hf_name: Full HuggingFace parameter name (e.g.,
+ 'model.multi_modal_projector.mm_soft_emb_norm.weight')
+
+ Returns:
+ GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
+ or None if no mapping found
+ """
+ # Strip 'language_model.' prefix for multimodal models - gguf-py
+ # tensor mappings expect parameter names without this prefix.
+ # Note: 'model.' prefix should be KEPT for text-only models as
+ # gguf-py expects it.
+ if hf_name.startswith("language_model."):
+ hf_name = hf_name[15:] # Remove 'language_model.'
+
+ # Parse parameter name and suffix
+ if hf_name.endswith((".weight", ".bias")):
+ base_name, suffix = hf_name.rsplit(".", 1)
+ else:
+ base_name, suffix = hf_name, ""
+ # Handle '_weight' suffix (Gemma3 naming: parameter ends with
+ # '_weight' instead of '.weight')
+ if base_name.endswith("_weight"):
+ base_name = base_name[:-7] # Remove '_weight'
+ suffix = "weight"
+
+ gguf_name = None
+ # Priority 1: Search vision/projector parameters for multimodal models
+ if vision_name_map is not None:
+ gguf_name = vision_name_map.get_name(base_name)
+
+ # Priority 2: Search text backbone parameters
+ if gguf_name is None:
+ gguf_name = text_name_map.get_name(base_name)
+
+ if gguf_name is None:
+ return None
+
+ return gguf_name + "." + suffix
+
+ # Build mapping and track unmapped parameters
+ unmapped_params = []
for hf_name in state_dict:
- name, suffix = hf_name.rsplit(".", 1)
- gguf_name = name_map.get_name(name)
- gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
+ gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)
+
+ # Track mapping success
+ if gguf_name_with_suffix is not None:
+ gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
+ logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
+ elif hf_name not in gguf_to_hf_name_map.values():
+ # Parameter not in manual overrides either
+ unmapped_params.append(hf_name)
+
+ # All parameters must be mapped: both vision/projector and backbone
+ if unmapped_params:
+ raise RuntimeError(
+ f"Failed to map GGUF parameters "
+ f"({len(unmapped_params)}): "
+ f"{unmapped_params}"
+ )
return gguf_to_hf_name_map
+ def _get_gguf_weight_type(
+ self,
+ model_config: ModelConfig,
+ model_name_or_path: str,
+ gguf_to_hf_name_map: dict[str, str],
+ ) -> dict[str, str]:
+ weight_type_map = get_gguf_weight_type_map(
+ model_config.model, gguf_to_hf_name_map
+ )
+ is_multimodal = hasattr(model_config.hf_config, "vision_config")
+ if is_multimodal:
+ mmproj_file = detect_gguf_multimodal(model_name_or_path)
+ assert mmproj_file is not None, (
+ "Could not find mm_proj file for multimodal GGUF model"
+ )
+ logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
+ mm_proj_weight_type_map = get_gguf_weight_type_map(
+ mmproj_file, gguf_to_hf_name_map
+ )
+ weight_type_map.update(mm_proj_weight_type_map)
+ return weight_type_map
+
def _get_weights_iterator(
- self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
+ self,
+ model_config: ModelConfig,
+ model_name_or_path: str,
+ gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, torch.Tensor], None, None]:
- return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
+ """
+ Iterate over GGUF model weights, loading from both main model file and
+ mmproj.gguf for multimodal Gemma3 models.
+
+ For Gemma3 multimodal GGUF models:
+ - Main file (gemma-3-*.gguf): Language model weights (model.*)
+ - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)
+
+ Yields:
+ Tuples of (parameter_name, tensor) for all model weights
+ """
+ hf_config = model_config.hf_config
+ is_multimodal = hasattr(hf_config, "vision_config")
+
+ if is_multimodal:
+ # Load mm_proj (mm_encoder + projector) for multimodal weights
+ mmproj_file = detect_gguf_multimodal(model_name_or_path)
+ assert mmproj_file is not None, (
+ "Could not find mm_proj file for multimodal GGUF model"
+ )
+ yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)
+
+ yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)
@@ -141,7 +296,7 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
- self._get_weights_iterator(local_model_path, gguf_weights_map)
+ self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
)
def load_model(
@@ -156,14 +311,19 @@ class GGUFModelLoader(BaseModelLoader):
):
model_config.hf_config.update({"tie_word_embeddings": True})
- weight_type_map = get_gguf_weight_type_map(model_config.model, gguf_weights_map)
-
+ weight_type_map = self._get_gguf_weight_type(
+ model_config, local_model_path, gguf_weights_map
+ )
# filter out unquantized modules to skip
unquant_names = [
name.removesuffix(".weight")
for name, weight_type in weight_type_map.items()
- if weight_type == "F32" and name.endswith(".weight")
+ if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
]
+ logger.debug(
+ "GGUF unquantized modules: %s",
+ unquant_names,
+ )
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
target_device = torch.device(device_config.device)
diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py
index 890dd7231a0e1..f330af85bbe8b 100644
--- a/vllm/model_executor/model_loader/online_quantization.py
+++ b/vllm/model_executor/model_loader/online_quantization.py
@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
+from collections.abc import Iterable
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger
-from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import process_weights_after_loading
logger = init_logger(__name__)
@@ -56,6 +56,9 @@ logger = init_logger(__name__)
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
+# R5. (workaround for cudagraph): we restore the weight params to the original
+# quantized weight params, and use original_weight_param.copy_(updated_weight_param)
+# so that the weight update works well with cudagraph
# process_weights_after_loading (if called):
# this will be skipped since it's already ran in
# load_weights
@@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
if model_config.quantization != "torchao":
return
- if getattr(model, "process_weights_after_loading_already_called", False):
- # In case `process_weights_after_loading` is called multiple times
- # we'll skip it at later times
- logger.warning(
- "process_weights_after_loading already called for model %s", model
- )
- return
-
from vllm.model_executor.model_loader.weight_utils import get_quant_config
quant_config = get_quant_config(model_config, None)
@@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
else:
model.recorded_weight_attr[name][key] = attr
# mark the metadata and attributes saved so we don't run it again
+ model._model_config = model_config
model.weight_metadata_and_attr_saved = True
@@ -148,77 +144,132 @@ def _bond_method_to_cls(func, obj):
return types.MethodType(func, obj)
-def load_weights_and_online_quantize(
- model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
-) -> set[str]:
+def support_quantized_model_reload_from_hp_weights(original_load_weights):
+ """Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
+ reloading high precision (bfloat16/float16/float32) weights for an already quantized
+ model; this involves restoring the weights to high precision weights and
+ then online quantize the weights
+ """
# online quantization, right now only enabled for
# torchao
- # R1, R2, R3, R4 in the Notes
+ # R1, R2, R3, R4, R5 in the Notes
- # TODO: Add fp8 support
- assert model_config.quantization == "torchao", (
- "online quantization is only enabled for torchao currently"
- )
- # TODO: use create_weights to restore the weights to original state
+ def patched_model_load_weights(
+ auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
+ ) -> set[str]:
+ model = auto_weight_loader.module
+ offline_quantization_or_first_run_of_online_quantization = not getattr(
+ model, "weight_metadata_and_attr_saved", False
+ )
- # Step R1: First restore the quantized weights to original bfloat16
- # weights, with original metadata (shape, dtype, device)
- # and attributes, so that bfloat16 weights can be loaded properly
- existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
- named_modules = dict(model.named_modules(remove_duplicate=False))
- model_device = None
+ # if we don't have `model.weight_metadata_and_attr_saved` defined and
+ # set to True, it means that this is either offline quantization case
+ # or the first run of online quantization
+ # see Notes in this file for more details
+ if offline_quantization_or_first_run_of_online_quantization:
+ # case 1: offline quantized checkpoint
+ # case 2: Step I1 first run of weight loading with
+ # online quantization
+ return original_load_weights(auto_weight_loader, weights, mapper=mapper)
- # Step R2: recover the parameter to the state before first loading
- for name, d in model.original_weights_rebuild_keys.items():
- _shape = d["shape"]
- _dtype = d["dtype"]
- _device = d["device"]
+ model_config = model._model_config
+
+ # TODO: Add fp8 support
+ assert model_config.quantization == "torchao", (
+ "online quantization is only enabled for torchao currently"
+ )
+ # TODO: use create_weights to restore the weights to original state
+
+ # Step R1: First restore the quantized weights to original bfloat16
+ # weights, with original metadata (shape, dtype, device)
+ # and attributes, so that bfloat16 weights can be loaded properly
+ # TODO: maybe set remove_duplicate to True?
+ original_quantized_weight_dict = dict(
+ model.named_parameters(remove_duplicate=False)
+ )
+ named_modules = dict(model.named_modules(remove_duplicate=False))
+ model_device = None
+
+ for name, d in model.original_weights_rebuild_keys.items():
+ _shape = d["shape"]
+ _dtype = d["dtype"]
+ _device = d["device"]
+ if model_device is not None:
+ assert model_device == _device, (
+ "Expecting all weights "
+ "to be in the same device for now, got both: "
+ f"{model_device} and {_device}"
+ )
+ else:
+ model_device = _device
+
+ if name in original_quantized_weight_dict:
+ module_name, weight_name = name.rsplit(".", 1)
+ module = named_modules[module_name]
+ setattr(
+ module,
+ weight_name,
+ torch.nn.Parameter(
+ torch.empty(_shape, dtype=_dtype, device=_device),
+ requires_grad=False,
+ ),
+ )
+
+ # Step R2: recover the weight attributes to the state before first loading
+ # recorded_weight_attr is
+ # {"weight_name": {"weight_attr_key": attr}}
+ # e.g.
+ # {
+ # {
+ # "layer.0.weight": {
+ # "weight_loader": weight_loader_function_object,
+ # "input_dim": 0, ...
+ # },
+ # "layer.1.weight": ...,
+ # }
+ # }
+ for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
+ for attr_name, attr in weight_attr_dict.items():
+ module_name, weight_name = full_weight_name.rsplit(".", 1)
+ module = named_modules[module_name]
+ weight = getattr(module, weight_name)
+ if not hasattr(weight, attr_name):
+ setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
+
+ # Step R3: reload bfloat16 / high precision weights
+ updated_params = original_load_weights(
+ auto_weight_loader, weights, mapper=mapper
+ )
+
+ # Step R4: online quantize the weights
+ # manually process weights after loading
+ model.process_weights_after_loading_already_called = False
if model_device is not None:
- assert model_device == _device, (
- "Expecting all weights "
- "to be in the same device for now, got both: "
- f"{model_device} and {_device}"
- )
+ process_weights_after_loading(model, model_config, model_device)
else:
- model_device = _device
-
- if name in existing_param_names:
- module_name, weight_name = name.rsplit(".", 1)
- module = named_modules[module_name]
- setattr(
- module,
- weight_name,
- torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
+ logger.warning_once(
+ "model_device is None, skip calling process_weights_after_loading"
)
- # recorded_weight_attr is
- # {"weight_name": {"weight_attr_key": attr}}
- # e.g.
- # {
- # {
- # "layer.0.weight": {
- # "weight_loader": weight_loader_function_object,
- # "input_dim": 0, ...
- # },
- # "layer.1.weight": ...,
- # }
- # }
- for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
- for attr_name, attr in weight_attr_dict.items():
- module_name, weight_name = full_weight_name.rsplit(".", 1)
- module = named_modules[module_name]
- weight = getattr(module, weight_name)
- if not hasattr(weight, attr_name):
- setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
+ # Step R5 (workaround for cudagraph): restore the original quantized weights
+ # and do a copy_ of the current weights to the original weights
+ updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
+ for name in model.original_weights_rebuild_keys:
+ if name in original_quantized_weight_dict:
+ original_quantized_weight = original_quantized_weight_dict[name]
+ updated_quantized_weight = updated_quantized_weights[name]
- # Step I1: reload bfloat16 / high precision weights
- loaded_weights = model.load_weights(
- model_loader.get_all_weights(model_config, model)
- )
+ module_name, weight_name = name.rsplit(".", 1)
+ module = named_modules[module_name]
+ setattr(module, weight_name, original_quantized_weight)
+ with torch.no_grad():
+ original_quantized_weight.copy_(updated_quantized_weight)
- # Step I2: online quantize the weights
- # manually process weights after loading
- model.process_weights_after_loading_already_called = False
- process_weights_after_loading(model, model_config, model_device)
- model.process_weights_after_loading_already_called = True
- return loaded_weights
+ del original_quantized_weight_dict
+ del named_modules
+ del updated_quantized_weight
+
+ model.process_weights_after_loading_already_called = True
+ return updated_params
+
+ return patched_model_load_weights
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index ba708a098c0da..e74434e9d12cb 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -88,6 +88,14 @@ def initialize_model(
def process_weights_after_loading(
model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
+ if getattr(model, "process_weights_after_loading_already_called", False):
+ # In case `process_weights_after_loading` is called multiple times
+ # we'll skip it at later times
+ logger.debug_once(
+ "process_weights_after_loading already called for model %s", model
+ )
+ return
+
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
maybe_save_metadata_and_attributes_for_weight_reloading,
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 93986e5f2fc0a..89634cbf41241 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -836,7 +836,11 @@ def gguf_quant_weights_iterator(
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""
Iterate over the quant weights in the model gguf files and convert
- them to torch tensors
+ them to torch tensors.
+ Be careful of the order of yielding weight types and weight data:
+ we have to yield all weight types first before yielding any weights.
+ Otherwise it would cause issues when loading weights for a packed
+ layer with different quant types.
"""
reader = gguf.GGUFReader(gguf_file)
@@ -846,7 +850,7 @@ def gguf_quant_weights_iterator(
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]
- if weight_type.name != "F32":
+ if weight_type.name not in ("F32", "BF16", "F16"):
weight_type_name = name.replace("weight", "qweight_type")
weight_type = torch.tensor(weight_type)
yield weight_type_name, weight_type
@@ -856,7 +860,7 @@ def gguf_quant_weights_iterator(
weight = tensor.data
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]
- if weight_type.name != "F32":
+ if weight_type.name not in ("F32", "BF16", "F16"):
name = name.replace("weight", "qweight")
param = torch.tensor(weight)
yield name, param
diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py
new file mode 100644
index 0000000000000..6f654f47495f7
--- /dev/null
+++ b/vllm/model_executor/models/afmoe.py
@@ -0,0 +1,711 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only AfMoE model compatible with HuggingFace weights."""
+
+import typing
+from collections.abc import Callable, Iterable
+from itertools import islice
+from typing import Any
+
+import torch
+from torch import nn
+
+from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (
+ get_ep_group,
+ get_pp_group,
+ get_tensor_model_parallel_world_size,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
+from vllm.model_executor.models.utils import (
+ AutoWeightsLoader,
+ PPMissingLayer,
+ WeightsMapper,
+ extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+from vllm.sequence import IntermediateTensors
+
+logger = init_logger(__name__)
+
+
+class AfmoeMoE(nn.Module):
+ def __init__(
+ self,
+ config, # AfmoeConfig
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ enable_eplb: bool = False,
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.route_scale = config.route_scale
+ self.score_func = config.score_func
+ self.route_norm = config.route_norm
+
+ self.ep_group = get_ep_group().device_group
+ self.ep_rank = self.ep_group.rank()
+ self.ep_size = self.ep_group.size()
+ self.n_routed_experts: int = config.num_experts
+ self.n_shared_experts: int = config.num_shared_experts
+
+ if config.hidden_act != "silu":
+ raise ValueError(
+ f"Unsupported activation: {config.hidden_act}. "
+ "Only silu is supported for now."
+ )
+
+ # Router gate
+ self.gate = nn.Linear(
+ config.hidden_size,
+ config.num_experts,
+ bias=False,
+ dtype=torch.float32,
+ )
+ self.expert_bias = nn.Parameter(
+ torch.empty(config.num_experts, dtype=torch.float32)
+ )
+
+ # Load balancing settings
+ vllm_config = get_current_vllm_config()
+ eplb_config = vllm_config.parallel_config.eplb_config
+ self.enable_eplb = enable_eplb
+
+ self.n_redundant_experts = eplb_config.num_redundant_experts
+ self.n_logical_experts = self.n_routed_experts
+ self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+ self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+ self.physical_expert_end = (
+ self.physical_expert_start + self.n_local_physical_experts
+ )
+
+ self.shared_experts = None
+ # Shared experts
+ if config.num_shared_experts > 0:
+ intermediate_size = config.moe_intermediate_size * config.num_shared_experts
+ self.shared_experts = AfmoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_experts",
+ )
+
+ # Routed experts using SharedFusedMoE
+ self.experts = SharedFusedMoE(
+ shared_experts=self.shared_experts,
+ num_experts=config.num_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=self.route_norm if self.score_func == "sigmoid" else False,
+ quant_config=quant_config,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ topk_group=config.topk_group,
+ prefix=f"{prefix}.experts",
+ scoring_func=self.score_func,
+ routed_scaling_factor=self.route_scale,
+ e_score_correction_bias=self.expert_bias,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ router_logits = self.gate(hidden_states.to(dtype=torch.float32))
+
+ fused_moe_out = self.experts(
+ hidden_states=hidden_states, router_logits=router_logits
+ )
+
+ if self.shared_experts is not None:
+ shared_output, final_hidden_states = fused_moe_out
+ final_hidden_states = final_hidden_states + shared_output
+ else:
+ final_hidden_states = fused_moe_out
+ if self.tp_size > 1:
+ final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states
+ )
+
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class AfmoeAttention(nn.Module):
+ def __init__(
+ self,
+ config, # AfmoeConfig
+ layer_idx: int,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ rope_theta: float = 10000,
+ rope_scaling: dict[str, Any] | None = None,
+ max_position_embeddings: int = 131072,
+ head_dim: int | None = None,
+ rms_norm_eps: float = 1e-05,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ attn_type: str = AttentionType.DECODER,
+ ) -> None:
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is at least the TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim or (hidden_size // self.total_num_heads)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+
+ # Check if this is a local attention layer
+ self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
+ self.sliding_window = config.sliding_window if self.is_local_attention else None
+
+ self.qkv_proj = QKVParallelLinear(
+ self.hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ # Gating projection
+ self.gate_proj = ColumnParallelLinear(
+ hidden_size,
+ self.total_num_heads * self.head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_proj",
+ )
+
+ # Q/K normalization
+ self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ # Only create rotary embeddings for local attention
+ if self.is_local_attention:
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ is_neox_style=True,
+ )
+ else:
+ self.rotary_emb = None
+
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ per_layer_sliding_window=self.sliding_window,
+ prefix=f"{prefix}.attn",
+ attn_type=attn_type,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ gate, _ = self.gate_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ # Apply Q/K normalization
+ q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(q.shape)
+ k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape(
+ k.shape
+ )
+
+ # Apply rotary embeddings only for local attention
+ if self.is_local_attention and self.rotary_emb is not None:
+ q, k = self.rotary_emb(positions, q, k)
+
+ attn_output = self.attn(q, k, v)
+
+ # Apply gating
+ attn_output = attn_output * torch.sigmoid(gate)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class AfmoeDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ config, # AfmoeConfig
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ enable_eplb: bool = False,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 10000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling is not None and getattr(
+ config, "original_max_position_embeddings", None
+ ):
+ rope_scaling["original_max_position_embeddings"] = (
+ config.original_max_position_embeddings
+ )
+ max_position_embeddings = getattr(config, "max_position_embeddings", 131072)
+
+ # DecoderLayers are created with `make_layers` which passes the prefix
+ # with the layer's index.
+ self.layer_idx = extract_layer_index(prefix)
+
+ self.self_attn = AfmoeAttention(
+ config=config,
+ layer_idx=self.layer_idx,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ head_dim=config.head_dim,
+ rms_norm_eps=config.rms_norm_eps,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ )
+
+ # MoE or dense FFN
+ self.moe_enabled = self.layer_idx >= config.num_dense_layers
+ if self.moe_enabled:
+ self.mlp = AfmoeMoE(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ enable_eplb=enable_eplb,
+ )
+ else:
+ self.mlp = AfmoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+ self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+ hidden_states = self.post_attention_layernorm(hidden_states) # attn norm b
+
+ # Fully Connected
+ hidden_states, residual = self.pre_mlp_layernorm( # ffn norm a
+ hidden_states, residual
+ )
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_mlp_layernorm(hidden_states) # ffn norm b
+
+ return hidden_states, residual
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ }
+)
+class AfmoeModel(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ enable_eplb = vllm_config.parallel_config.enable_eplb
+ self.config = config
+
+ self.vocab_size = config.vocab_size
+ self.mup_enabled = config.mup_enabled
+
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens"
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: AfmoeDecoderLayer(
+ config=config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix,
+ enable_eplb=enable_eplb,
+ ),
+ prefix=f"{prefix}.layers",
+ )
+
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.embed_input_ids(input_ids)
+
+ # Apply muP input scaling if enabled
+ if self.mup_enabled:
+ hidden_states = hidden_states * (self.config.hidden_size**0.5)
+
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for layer in islice(self.layers, self.start_layer, self.end_layer):
+ hidden_states, residual = layer(positions, hidden_states, residual)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors(
+ {"hidden_states": hidden_states, "residual": residual}
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ def make_empty_intermediate_tensors(
+ self, batch_size: int, dtype: torch.dtype, device: torch.device
+ ) -> IntermediateTensors:
+ return IntermediateTensors(
+ {
+ "hidden_states": torch.zeros(
+ (batch_size, self.config.hidden_size), dtype=dtype, device=device
+ ),
+ "residual": torch.zeros(
+ (batch_size, self.config.hidden_size), dtype=dtype, device=device
+ ),
+ }
+ )
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ return SharedFusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.num_experts,
+ )
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ expert_params_mapping = self.get_expert_mapping()
+
+ for name, loaded_weight in weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if (weight_name not in name) or ("self_attn.gate_proj" in name):
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+                # to mlp.experts[0].gate_gate_up_proj, which would break loading.
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
+
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ is_expert_weight = False
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+
+ # Anyway, this is an expert weight and should not be
+ # attempted to load as other weights later
+ is_expert_weight = True
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or not
+ # here since otherwise we may skip experts with other
+ # available replicas.
+ weight_loader = typing.cast(
+ Callable[..., bool], param.weight_loader
+ )
+ success = weight_loader(
+ param,
+ loaded_weight,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ name = name_mapped
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
+
+
+class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_suffix={
+ ".router.gate.weight": ".gate.weight",
+ },
+ )
+
+ fall_back_to_pt_during_load = False
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+ self.model = AfmoeModel(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+ )
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, quant_config=quant_config
+ )
+ else:
+ self.lm_head = PPMissingLayer()
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors
+ )
+ self.expert_weights = []
+
+ # Set MoE hyperparameters
+ self.num_moe_layers = config.num_hidden_layers - config.num_dense_layers
+ self.num_expert_groups = config.n_group
+
+ self.moe_layers: list[SharedFusedMoE] = []
+ example_moe = None
+ for layer in self.model.layers:
+ if isinstance(layer, PPMissingLayer):
+ continue
+
+ assert isinstance(layer, AfmoeDecoderLayer)
+ if layer.moe_enabled:
+ example_moe = layer.mlp
+ self.moe_layers.append(layer.mlp.experts)
+
+ if example_moe is None and self.num_moe_layers > 0:
+ raise RuntimeError("No AfmoeMoE layer found in model.layers.")
+
+ if example_moe is not None:
+ self.num_logical_experts = example_moe.n_logical_experts
+ self.num_physical_experts = example_moe.n_physical_experts
+ self.num_local_physical_experts = example_moe.n_local_physical_experts
+ self.num_routed_experts = example_moe.n_routed_experts
+ self.num_shared_experts = example_moe.n_shared_experts
+ self.num_redundant_experts = example_moe.n_redundant_experts
+
+ def set_eplb_state(
+ self,
+ expert_load_view: torch.Tensor,
+ logical_to_physical_map: torch.Tensor,
+ logical_replica_count: torch.Tensor,
+ ) -> None:
+ for layer_idx, layer in enumerate(self.moe_layers):
+ # Register the expert weights.
+ self.expert_weights.append(layer.get_expert_weights())
+ layer.set_eplb_state(
+ moe_layer_idx=layer_idx,
+ expert_load_view=expert_load_view,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ hidden_states = self.model(
+ input_ids, positions, intermediate_tensors, inputs_embeds
+ )
+ return hidden_states
+
+ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
+ logits = self.logits_processor(self.lm_head, hidden_states)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py
index 5872e8196eada..3d000f3ac3ab5 100644
--- a/vllm/model_executor/models/aimv2.py
+++ b/vllm/model_executor/models/aimv2.py
@@ -12,6 +12,7 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -58,7 +59,7 @@ class AIMv2SwiGLUFFN(nn.Module):
class AIMv2PatchEmbed(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
- self.proj = nn.Conv2d(
+ self.proj = Conv2dLayer(
config.num_channels,
config.hidden_size,
kernel_size=(config.patch_size, config.patch_size),
diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index 6e1e5b1ddc509..024425bb24406 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -599,7 +599,7 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
- prefix=f"{prefix}.lm_head",
+ prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(config.vocab_size)
else:
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index e0a2defd5127e..c6cc83487fec2 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -138,8 +138,7 @@ class BambaMixerDecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
- output = torch.empty_like(hidden_states)
- self.mamba(hidden_states, output)
+ output = self.mamba(hidden_states)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states)
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index 2e4f73312efa3..f31f99c0592b2 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -12,6 +12,7 @@ from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -47,7 +48,7 @@ class BlipVisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index fb7476c45fcdb..3c87bbfefab3d 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -22,6 +22,7 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -549,7 +550,7 @@ class ChameleonVQVAEVectorQuantizer(nn.Module):
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
- self.conv = nn.Conv2d(
+ self.conv = Conv2dLayer(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
@@ -577,23 +578,23 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
self.norm1 = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
- self.conv1 = torch.nn.Conv2d(
+ self.conv1 = Conv2dLayer(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = torch.nn.GroupNorm(
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
)
self.dropout = torch.nn.Dropout(config.dropout)
- self.conv2 = torch.nn.Conv2d(
+ self.conv2 = Conv2dLayer(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(
+ self.conv_shortcut = Conv2dLayer(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
- self.nin_shortcut = torch.nn.Conv2d(
+ self.nin_shortcut = Conv2dLayer(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
@@ -626,16 +627,16 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
- self.q = torch.nn.Conv2d(
+ self.q = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
- self.k = torch.nn.Conv2d(
+ self.k = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
- self.v = torch.nn.Conv2d(
+ self.v = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
- self.proj_out = torch.nn.Conv2d(
+ self.proj_out = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
@@ -681,7 +682,7 @@ class ChameleonVQVAEEncoder(nn.Module):
latent_channels = config.latent_channels
channel_multiplier = config.channel_multiplier
- self.conv_in = torch.nn.Conv2d(
+ self.conv_in = Conv2dLayer(
in_channels, base_channels, kernel_size=3, stride=1, padding=1
)
@@ -738,7 +739,7 @@ class ChameleonVQVAEEncoder(nn.Module):
self.norm_out = torch.nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
- self.conv_out = torch.nn.Conv2d(
+ self.conv_out = Conv2dLayer(
block_in,
2 * latent_channels if double_latent else latent_channels,
kernel_size=3,
@@ -779,10 +780,8 @@ class ChameleonVQVAE(nn.Module):
super().__init__()
self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
- self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(
- config.embed_dim, config.latent_channels, 1
- )
+ self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
+ self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
def encode(
diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py
index e62a57eccc953..8f1660891fcbf 100644
--- a/vllm/model_executor/models/deepencoder.py
+++ b/vllm/model_executor/models/deepencoder.py
@@ -19,6 +19,7 @@ import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -133,14 +134,14 @@ class ImageEncoderViT(nn.Module):
self.blocks.append(block)
self.neck = nn.Sequential(
- nn.Conv2d(
+ Conv2dLayer(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
- nn.Conv2d(
+ Conv2dLayer(
out_chans,
out_chans,
kernel_size=3,
@@ -150,8 +151,10 @@ class ImageEncoderViT(nn.Module):
LayerNorm2d(out_chans),
)
- self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
- self.net_3 = nn.Conv2d(
+ self.net_2 = Conv2dLayer(
+ 256, 512, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.net_3 = Conv2dLayer(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)
@@ -500,7 +503,7 @@ class PatchEmbed(nn.Module):
"""
super().__init__()
- self.proj = nn.Conv2d(
+ self.proj = Conv2dLayer(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 9e834a73f8e5e..3fb04c3b70dd1 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -26,7 +26,7 @@ from vllm.model_executor.models.deepseek_v2 import (
)
from vllm.utils import init_logger
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -250,6 +250,7 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
+ process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 115818d903a6d..e8ee9951d6119 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -85,7 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
@@ -1311,7 +1311,7 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
class DeepseekV2ForCausalLM(
- nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
+ nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 405af8f8be426..2d2251e83b5b1 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -22,6 +22,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -39,8 +40,8 @@ from vllm.model_executor.models.interfaces import (
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
-from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
from vllm.model_executor.models.qwen2_vl import (
+ Qwen2VisionAttention,
Qwen2VLDummyInputsBuilder,
Qwen2VLMultiModalProcessor,
Qwen2VLProcessingInfo,
@@ -328,7 +329,7 @@ class DotsVisionAttention(nn.Module):
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
x, _ = self.qkv(x)
- q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
+ q, k, v = Qwen2VisionAttention.split_qkv(self, x)
bs = q.shape[1]
# [S,B,H,D] -> [B,S,H,D]
q = q.permute(1, 0, 2, 3).contiguous()
@@ -471,7 +472,7 @@ class DotsPatchEmbed(nn.Module):
self.temporal_patch_size = config.temporal_patch_size
self.embed_dim = config.embed_dim
self.config = config
- self.proj = nn.Conv2d(
+ self.proj = Conv2dLayer(
config.num_channels,
config.embed_dim,
kernel_size=(config.patch_size, config.patch_size),
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index 3653425b8e1ca..b985847af5daf 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -198,10 +198,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
residual: torch.Tensor | None,
**kwargs,
):
- output = torch.empty_like(hidden_states)
- self.mamba(
+ output = self.mamba(
hidden_states,
- output,
mup_vector=self.mup_vector,
)
return output, residual
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 02fb7ef31dc94..fe83c8b63b018 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -596,13 +596,33 @@ class Gemma3ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
- def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
return self._process_image_input(image_input)
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: MultiModalEmbeddings | None = None,
+ *,
+ is_multimodal: torch.Tensor | None = None,
+ handle_oov_mm_token: bool = True,
+ ) -> torch.Tensor:
+ # Early return for text-only inference (no multimodal data)
+ if multimodal_embeddings is None or is_multimodal is None:
+ return super().embed_input_ids(input_ids)
+
+ # Use interface default with OOV handling enabled
+ return super().embed_input_ids(
+ input_ids,
+ multimodal_embeddings=multimodal_embeddings,
+ is_multimodal=is_multimodal,
+ handle_oov_mm_token=handle_oov_mm_token,
+ )
+
def forward(
self,
input_ids: torch.Tensor,
@@ -624,6 +644,79 @@ class Gemma3ForConditionalGeneration(
return hidden_states
+ def generate_attention_masks(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ mask_dtype: torch.dtype,
+ ) -> dict[str, Any]:
+ """Generate custom attention masks for Gemma3 multimodal inputs.
+
+ This is called by V1 engine's gpu_model_runner during preprocessing
+ to generate attention masks that allow bidirectional attention between
+ image tokens while maintaining causal attention for text.
+ """
+ # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
+ # This is a HACK. Fix this.
+ start_indices = (positions == 0).cpu().nonzero()
+ num_seqs = len(start_indices)
+ seq_lens = []
+ for i in range(num_seqs):
+ start_idx = start_indices[i]
+ end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
+ seq_lens.append(end_idx - start_idx)
+
+ global_attn_masks = []
+ local_attn_masks = []
+ start_idx = 0
+ for seq_idx, seq_len in enumerate(seq_lens):
+ end_idx = start_idx + seq_len
+ input_token_ids = input_ids[start_idx:end_idx]
+
+ # Find image token positions
+ img_pos = input_token_ids == self.config.image_token_index
+
+ start_idx = end_idx
+
+ # Create a global causal mask
+ global_attn_mask = torch.empty(
+ 1,
+ 1,
+ seq_len,
+ seq_len,
+ dtype=mask_dtype,
+ device=input_ids.device,
+ )
+ global_attn_mask.fill_(float("-inf"))
+ # Fill the lower triangle with 0 (causal attention)
+ global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+ # Enable bidirectional attention between image tokens
+ img_mask = torch.zeros_like(global_attn_mask)
+ img_mask[:, :, :, img_pos] += 1
+ img_mask[:, :, img_pos, :] += 1
+ global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
+ global_attn_masks.append(global_attn_mask)
+
+ # GGUF compatibility: config might be Gemma3TextConfig directly
+ text_config = getattr(self.config, "text_config", self.config)
+ sliding_window = text_config.sliding_window
+ if sliding_window is not None:
+ # Create a local causal mask with sliding window
+ local_attn_mask = torch.ones_like(global_attn_mask)
+ local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
+ local_attn_mask = torch.where(
+ local_attn_mask == 0, global_attn_mask, float("-inf")
+ )
+ local_attn_masks.append(local_attn_mask)
+
+ return {
+ "has_images": True,
+ "seq_lens": seq_lens,
+ "global_attn_masks": global_attn_masks,
+ "local_attn_masks": local_attn_masks,
+ }
+
def prepare_attn_masks(
self,
input_ids: torch.Tensor,
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 6953b805653b4..2c2f45c2453ee 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -56,7 +56,7 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
-from vllm.model_executor.layers.conv import Conv3dLayer
+from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -65,6 +65,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -341,7 +342,8 @@ class Glm4vVisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
@@ -353,10 +355,12 @@ class Glm4vVisionAttention(nn.Module):
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
- if rotary_pos_emb is not None:
+ if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
- qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+ qk_rotated = apply_rotary_pos_emb_vision(
+ qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
+ )
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
@@ -454,14 +458,16 @@ class Glm4vVisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -660,44 +666,6 @@ class Glm4vVisionEmbeddings(nn.Module):
return embeddings
-class Glm4vVisionRotaryEmbedding(nn.Module):
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
- super().__init__()
- self.dim = dim
- self.theta = theta
- inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self._seq_len_cached = 0
- self._freqs_cached = None
-
- def update_freqs_cache(self, seqlen: int) -> None:
- if seqlen > self._seq_len_cached:
- seqlen *= 2
- self._seq_len_cached = seqlen
- self.inv_freq = 1.0 / (
- self.theta
- ** (
- torch.arange(
- 0,
- self.dim,
- 2,
- dtype=torch.float,
- device=self.inv_freq.device,
- )
- / self.dim
- )
- )
- seq = torch.arange(
- seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
- )
- freqs = torch.outer(seq, self.inv_freq)
- self._freqs_cached = freqs
-
- def forward(self, seqlen: int) -> torch.Tensor:
- self.update_freqs_cache(seqlen)
- return self._freqs_cached[:seqlen]
-
-
class Glm4vVisionTransformer(nn.Module):
def __init__(
self,
@@ -731,7 +699,13 @@ class Glm4vVisionTransformer(nn.Module):
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
- self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+ self.rotary_pos_emb = get_rope(
+ head_size=head_dim,
+ rotary_dim=head_dim // 2,
+ max_position=8192,
+ base=10000.0,
+ is_neox_style=True,
+ )
self.blocks = nn.ModuleList(
[
Glm4vVisionBlock(
@@ -760,7 +734,7 @@ class Glm4vVisionTransformer(nn.Module):
self.post_conv_layernorm = RMSNorm(
vision_config.hidden_size, eps=vision_config.rms_norm_eps
)
- self.downsample = nn.Conv2d(
+ self.downsample = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=vision_config.out_hidden_size,
kernel_size=vision_config.spatial_merge_size,
@@ -789,7 +763,9 @@ class Glm4vVisionTransformer(nn.Module):
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
- def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+ def rot_pos_emb(
+ self, grid_thw: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@@ -817,9 +793,18 @@ class Glm4vVisionTransformer(nn.Module):
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
- rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
- rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
- return rotary_pos_emb, pos_ids
+
+ # Use pre-computed cos_sin_cache from RotaryEmbedding
+ cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+ cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
+ cos_w = cos[pos_ids[:, 1]]
+ sin_h = sin[pos_ids[:, 0]]
+ sin_w = sin[pos_ids[:, 1]]
+
+ cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+ sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+ return cos_combined, sin_combined, pos_ids
def compute_attn_mask_seqlen(
self,
@@ -848,7 +833,9 @@ class Glm4vVisionTransformer(nn.Module):
x = self.post_conv_layernorm(x)
# compute position embedding
- rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+ rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
+ grid_thw
+ )
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -867,7 +854,8 @@ class Glm4vVisionTransformer(nn.Module):
x = blk(
x,
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py
index 110ed0a646334..e34ae6c85a4f8 100644
--- a/vllm/model_executor/models/glm4_moe_mtp.py
+++ b/vllm/model_executor/models/glm4_moe_mtp.py
@@ -256,13 +256,12 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
- spec_layer = self.model.mtp_start_layer_idx
for name, loaded_weight in weights:
if name == "lm_head.weight":
- name = f"model.layers.{spec_layer}.shard_head.head.weight"
+ spec_layer = self.model.mtp_start_layer_idx
+ name = f"model.layers.{spec_layer}.shared_head.head.weight"
elif name == "model.embed_tokens.weight":
- # This name is same with local model, rewriting is not needed.
- pass
+ spec_layer = self.model.mtp_start_layer_idx
else:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py
index 1c18ea0745f2b..514082cf60ce2 100644
--- a/vllm/model_executor/models/glm4v.py
+++ b/vllm/model_executor/models/glm4v.py
@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -78,7 +79,7 @@ class GLMVImagePixelInputs(TensorSchema):
class EVA2CLIPPatchEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
- self.proj = nn.Conv2d(
+ self.proj = Conv2dLayer(
config.in_channels,
config.hidden_size,
kernel_size=config.patch_size,
@@ -333,7 +334,7 @@ class EVA2CLIPModel(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.linear_proj",
)
- self.conv = nn.Conv2d(
+ self.conv = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 692ef605fe175..7df3b087ccb88 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -494,8 +494,8 @@ class GptOssModel(nn.Module):
def _load_weights_other(
self,
- ep_rank_start: int,
ep_rank_end: int,
+ ep_rank_start: int,
heads_per_rank: int,
head_start: int,
weights: Iterable[tuple[str, torch.Tensor]],
@@ -641,8 +641,8 @@ class GptOssModel(nn.Module):
)
else:
return self._load_weights_other(
- ep_rank_end,
ep_rank_start,
+ ep_rank_end,
heads_per_rank,
head_start,
weights,
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 05177f1d1ac2c..a340112ec62ae 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -115,8 +115,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
- output = torch.empty_like(hidden_states)
- self.mamba(hidden_states, output)
+ output = self.mamba(hidden_states)
hidden_states = residual + output * self.residual_multiplier
residual = hidden_states
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index 727c8ec0397ca..06b8468e18db9 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -30,6 +30,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -60,7 +61,7 @@ class Idefics2VisionEmbeddings(nn.Module):
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 929bfaaee5cbb..dc4caf2f02f9d 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -932,13 +932,73 @@ def supports_transcription(
@runtime_checkable
-class SupportsEagle3(Protocol):
+class SupportsEagleBase(Protocol):
+ """Base interface for models that support EAGLE-based speculative decoding."""
+
+ has_own_lm_head: bool = False
+ """
+ A flag that indicates this model has trained its own lm_head.
+ """
+
+ has_own_embed_tokens: bool = False
+ """
+ A flag that indicates this model has trained its own input embeddings.
+ """
+
+
+@overload
+def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ...
+
+
+@overload
+def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ...
+
+
+def supports_any_eagle(
+ model: type[object] | object,
+) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
+ """Check if model supports any EAGLE variant (1, 2, or 3)."""
+ return supports_eagle(model) or supports_eagle3(model)
+
+
+@runtime_checkable
+class SupportsEagle(SupportsEagleBase, Protocol):
"""The interface required for models that support
- EAGLE3 speculative decoding."""
+ EAGLE-1 and EAGLE-2 speculative decoding."""
+
+ supports_eagle: ClassVar[Literal[True]] = True
+ """
+ A flag that indicates this model supports EAGLE-1 and EAGLE-2
+ speculative decoding.
+
+ Note:
+ There is no need to redefine this flag if this class is in the
+ MRO of your model class.
+ """
+
+
+@overload
+def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ...
+
+
+@overload
+def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ...
+
+
+def supports_eagle(
+ model: type[object] | object,
+) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
+ return isinstance(model, SupportsEagle)
+
+
+@runtime_checkable
+class SupportsEagle3(SupportsEagleBase, Protocol):
+ """The interface required for models that support
+ EAGLE-3 speculative decoding."""
supports_eagle3: ClassVar[Literal[True]] = True
"""
- A flag that indicates this model supports EAGLE3
+ A flag that indicates this model supports EAGLE-3
speculative decoding.
Note:
@@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol):
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""
Set which layers should output auxiliary
- hidden states for EAGLE3.
+ hidden states for EAGLE-3.
Args:
layers: Tuple of layer indices that should output auxiliary
@@ -960,7 +1020,7 @@ class SupportsEagle3(Protocol):
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""
Get the layer indices that should output auxiliary hidden states
- for EAGLE3.
+ for EAGLE-3.
Returns:
Tuple of layer indices for auxiliary hidden state outputs.
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 03918127c6ae1..61aeafc2ab436 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -24,6 +24,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -51,7 +52,7 @@ class InternVisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py
index 507503d75046d..cb0414bbc95a8 100644
--- a/vllm/model_executor/models/interns1_vit.py
+++ b/vllm/model_executor/models/interns1_vit.py
@@ -16,6 +16,7 @@ from transformers.utils import torch_int
from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -43,7 +44,7 @@ class InternS1VisionPatchEmbeddings(nn.Module):
self.num_patches = num_patches
self.patch_shape = patch_shape
- self.projection = nn.Conv2d(
+ self.projection = Conv2dLayer(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
)
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 1eb0eccc0411c..8fc3db296aa79 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -204,7 +205,7 @@ class KeyeVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c49a1ea817f91..0a3f37c30ab5f 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -529,7 +529,9 @@ class LlamaModel(nn.Module):
return loaded_params
-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class LlamaForCausalLM(
+ nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index e8716d652415e..660c8f1bb5226 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -35,7 +35,7 @@ from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausa
from vllm.model_executor.models.utils import extract_layer_index
from .interfaces import SupportsMultiModal
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -212,6 +212,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
if "lm_head" not in name:
name = "model." + name
+ process_eagle_weight(self, name)
return name, weight
loader = AutoWeightsLoader(
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index ab2a9f6f06dbe..90ab5c50361b6 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -11,13 +11,22 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import (
+ AutoWeightsLoader,
+ get_draft_quant_config,
+ maybe_prefix,
+ process_eagle_weight,
+)
logger = init_logger(__name__)
@@ -40,14 +49,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
"""Use drafter's quantization config instead of verifier's."""
- draft_model_config = vllm_config.speculative_config.draft_model_config
- draft_load_config = vllm_config.load_config
-
- return (
- VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
- if draft_model_config
- else None
- )
+ return get_draft_quant_config(vllm_config)
@support_torch_compile
@@ -63,6 +65,9 @@ class LlamaModel(nn.Module):
self.config = vllm_config.speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
+ # Get drafter's quantization config
+ self.quant_config = get_draft_quant_config(vllm_config)
+
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
@@ -80,8 +85,14 @@ class LlamaModel(nn.Module):
for i in range(self.config.num_hidden_layers)
]
)
- self.fc = torch.nn.Linear(
- self.config.hidden_size * 2, self.config.hidden_size, bias=False
+ self.fc = ReplicatedLinear(
+ input_size=self.config.hidden_size * 2,
+ output_size=self.config.hidden_size,
+ bias=False,
+ params_dtype=vllm_config.model_config.dtype,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "fc"),
+ return_bias=False,
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -117,6 +128,24 @@ class LlamaModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
+ # Handle kv cache quantization scales
+ if self.quant_config is not None and (
+ scale_name := self.quant_config.get_cache_scale(name)
+ ):
+ # Loading kv cache quantization scales
+ param = params_dict[scale_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ loaded_weight = (
+ loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+ )
+ weight_loader(param, loaded_weight)
+ loaded_params.add(scale_name)
+ continue
+ # Remap the FP8 kv-scale name
+ if "scale" in name:
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -179,6 +208,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
+ process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 6edc9519dfbbf..75c671311b491 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -11,19 +11,27 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import QKVParallelLinear
+from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import (
+ AutoWeightsLoader,
+ get_draft_quant_config,
+ maybe_prefix,
+ process_eagle_weight,
+)
logger = init_logger(__name__)
@@ -66,14 +74,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
"""Use drafter's quantization config instead of verifier's."""
- draft_model_config = vllm_config.speculative_config.draft_model_config
- draft_load_config = vllm_config.load_config
-
- return (
- VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
- if draft_model_config
- else None
- )
+ return get_draft_quant_config(vllm_config)
def _norm_before_residual(
self, hidden_states: torch.Tensor
@@ -140,6 +141,9 @@ class LlamaModel(nn.Module):
self.config = vllm_config.speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
+ # Get drafter's quantization config
+ self.quant_config = get_draft_quant_config(vllm_config)
+
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
@@ -160,13 +164,19 @@ class LlamaModel(nn.Module):
]
)
if hasattr(self.config, "target_hidden_size"):
- self.fc = torch.nn.Linear(
- self.config.target_hidden_size * 3, self.config.hidden_size, bias=False
- )
+ fc_input_size = self.config.target_hidden_size * 3
else:
- self.fc = torch.nn.Linear(
- self.config.hidden_size * 3, self.config.hidden_size, bias=False
- )
+ fc_input_size = self.config.hidden_size * 3
+ self.fc = ReplicatedLinear(
+ input_size=fc_input_size,
+ output_size=self.config.hidden_size,
+ bias=False,
+ params_dtype=vllm_config.model_config.dtype,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "fc"),
+ return_bias=False,
+ )
+
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
@@ -211,6 +221,24 @@ class LlamaModel(nn.Module):
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
+ # Handle kv cache quantization scales
+ if self.quant_config is not None and (
+ scale_name := self.quant_config.get_cache_scale(name)
+ ):
+ # Loading kv cache quantization scales
+ param = params_dict[scale_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ loaded_weight = (
+ loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+ )
+ weight_loader(param, loaded_weight)
+ loaded_params.add(scale_name)
+ continue
+ # Remap the FP8 kv-scale name
+ if "scale" in name:
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -324,6 +352,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
if "embed_tokens" in name:
includes_embed_tokens = True
model_weights[name] = loaded_weight
+ process_eagle_weight(self, name)
skip_substrs = []
if not includes_draft_id_mapping:
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index fc17f98be1986..5fcfa94312303 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -87,8 +87,7 @@ class Mamba2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
- output = torch.empty_like(hidden_states)
- self.mixer(hidden_states, output)
+ output = self.mixer(hidden_states)
return output, residual
diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py
index a84c99059cd9c..d9b23811730d4 100644
--- a/vllm/model_executor/models/midashenglm.py
+++ b/vllm/model_executor/models/midashenglm.py
@@ -39,6 +39,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -120,7 +121,7 @@ class AudioPatchEmbed(nn.Module):
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
- self.proj = nn.Conv2d(
+ self.proj = Conv2dLayer(
in_chans,
embed_dim,
kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py
index 0ca31913485db..d0cdb70aa8574 100644
--- a/vllm/model_executor/models/minicpm_eagle.py
+++ b/vllm/model_executor/models/minicpm_eagle.py
@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
@@ -52,6 +52,7 @@ from .utils import (
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
+ process_eagle_weight,
)
@@ -289,7 +290,7 @@ class EagleMiniCPMModel(nn.Module):
return loaded_params
-class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -376,8 +377,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ def transform(inputs):
+ name, loaded_weight = inputs
+ process_eagle_weight(self, name)
+ return name, loaded_weight
+
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
- return loader.load_weights(weights)
+ return loader.load_weights(map(transform, weights))
diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 8017c947bf9ad..2e3e6dc166ad8 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -53,6 +53,7 @@ from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
@@ -244,7 +245,7 @@ class MoonVisionPatchEmbed(nn.Module):
)
self.patch_size = patch_size
- self.proj = nn.Conv2d(
+ self.proj = Conv2dLayer(
in_dim, out_dim, kernel_size=patch_size, stride=patch_size
)
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index f7e0caf410e10..8675eff592224 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -376,8 +376,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
- output = torch.empty_like(hidden_states)
- self.mixer(hidden_states, output)
+ output = self.mixer(hidden_states)
return output, residual
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index 3ef6470070d18..dee0c16ab0f63 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -45,6 +45,7 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -419,7 +420,7 @@ class SiglipVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 8cb7d6a889da4..8a034fd72b02a 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -31,6 +31,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -747,7 +748,7 @@ class VisionTransformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
- self.patch_conv = nn.Conv2d(
+ self.patch_conv = Conv2dLayer(
in_channels=args.num_channels,
out_channels=args.hidden_size,
kernel_size=args.patch_size,
@@ -1212,7 +1213,7 @@ class PixtralHFVisionModel(nn.Module):
self.config = config
- self.patch_conv = nn.Conv2d(
+ self.patch_conv = Conv2dLayer(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 7617929e93ac4..2e4fd9645d88f 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -64,6 +64,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
@@ -359,44 +360,45 @@ class Qwen2_5_VisionAttention(nn.Module):
AttentionBackendEnum.ROCM_AITER_FA,
}
- def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
- # [s, b, 3 * head * head_dim]
- seq_len, bs, _ = qkv.shape
-
- # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
- q, k, v = qkv.chunk(3, dim=2)
-
- # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
- new_shape = (
- seq_len,
- bs,
- self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head,
- )
- q, k, v = (x.view(*new_shape) for x in (q, k, v))
- return q, k, v
-
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
+ seq_len, batch_size, _ = x.shape
- # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
- q, k, v = self.split_qkv(x)
- batch_size = q.shape[1]
+ qkv = einops.rearrange(
+ x,
+ "s b (three head head_dim) -> b s three head head_dim",
+ three=3,
+ head=self.num_attention_heads_per_partition,
+ )
- q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
- if rotary_pos_emb is not None:
- # [2 * b, s, heads, head_dim]
- qk_concat = torch.cat([q, k], dim=0)
- qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
- q, k = torch.chunk(qk_rotated, 2, dim=0)
+ if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
+ qk, v = qkv[:, :, :2], qkv[:, :, 2]
+
+ qk_reshaped = einops.rearrange(
+ qk, "b s two head head_dim -> (two b) s head head_dim", two=2
+ )
+ qk_rotated = apply_rotary_pos_emb_vision(
+ qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
+ )
+ qk_rotated = qk_rotated.view(
+ 2,
+ batch_size,
+ seq_len,
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ q, k = qk_rotated.unbind(dim=0)
+ else:
+ q, k, v = qkv.unbind(dim=2)
if self.is_flash_attn_backend:
context_layer = vit_flash_attn_wrapper(
@@ -436,7 +438,8 @@ class Qwen2_5_VisionAttention(nn.Module):
dynamic_arg_dims={
"x": 0,
"cu_seqlens": 0,
- "rotary_pos_emb": 0,
+ "rotary_pos_emb_cos": 0,
+ "rotary_pos_emb_sin": 0,
"seqlens": 0,
},
mark_unbacked_dims={"seqlens": 0},
@@ -487,14 +490,16 @@ class Qwen2_5_VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -590,42 +595,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
return out
-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
- super().__init__()
- self.dim = dim
- self.theta = theta
- inv_freq = 1.0 / (
- theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim)
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self._seq_len_cached = 0
- self._freqs_cached = None
-
- def update_freqs_cache(self, seqlen: int) -> None:
- if seqlen > self._seq_len_cached:
- seqlen *= 2
- self._seq_len_cached = seqlen
- self.inv_freq = 1.0 / (
- self.theta
- ** (
- torch.arange(
- 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
- )
- / self.dim
- )
- )
- seq = torch.arange(
- seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
- )
- freqs = torch.outer(seq, self.inv_freq)
- self._freqs_cached = freqs
-
- def forward(self, seqlen: int) -> torch.Tensor:
- self.update_freqs_cache(seqlen)
- return self._freqs_cached[:seqlen]
-
-
class Qwen2_5_VisionTransformer(nn.Module):
def __init__(
self,
@@ -668,7 +637,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
- self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+ self.rotary_pos_emb = get_rope(
+ head_size=head_dim,
+ rotary_dim=head_dim // 2,
+ max_position=8192,
+ base=10000.0,
+ is_neox_style=True,
+ )
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
@@ -759,15 +734,30 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
max_size = max(h, w)
- rotary_pos_emb_full = self.rotary_pos_emb(max_size)
- rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
- rotary_pos_emb = rotary_pos_emb.reshape(
- rotary_pos_emb.shape[0] // self.spatial_merge_unit,
+
+ # Use pre-computed cos_sin_cache from RotaryEmbedding
+ cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)
+
+ cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
+ cos_w = cos[pos_ids[:, 1]]
+ sin_h = sin[pos_ids[:, 0]]
+ sin_w = sin[pos_ids[:, 1]]
+
+ cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+ sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+
+ cos_combined = cos_combined.reshape(
+ cos_combined.shape[0] // self.spatial_merge_unit,
+ self.spatial_merge_unit,
+ -1,
+ )
+ sin_combined = sin_combined.reshape(
+ sin_combined.shape[0] // self.spatial_merge_unit,
self.spatial_merge_unit,
-1,
)
- return rotary_pos_emb
+ return cos_combined, sin_combined
def get_window_index_thw(self, grid_t, grid_h, grid_w):
vit_merger_window_size = (
@@ -809,14 +799,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
@lru_cache(maxsize=1024) # noqa: B019
def get_rope_by_thw(self, t, h, w):
window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w)
- rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
- rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
- rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1)
+ cos_thw, sin_thw = self.rotary_pos_emb_thw(t, h, w)
+
+ cos_thw = cos_thw[window_index_thw, :, :]
+ cos_thw = cos_thw.flatten(start_dim=0, end_dim=1)
+ sin_thw = sin_thw[window_index_thw, :, :]
+ sin_thw = sin_thw.flatten(start_dim=0, end_dim=1)
+
cu_seqlens_thw = torch.repeat_interleave(
torch.tensor([h * w], dtype=torch.int32), t
)
return (
- rotary_pos_emb_thw,
+ cos_thw,
+ sin_thw,
window_index_thw,
cu_seqlens_window_thw,
cu_seqlens_thw,
@@ -851,7 +846,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> torch.Tensor:
# patchify
seq_len, _ = x.size()
- rotary_pos_emb = []
+ rotary_pos_emb_cos = []
+ rotary_pos_emb_sin = []
window_index: list = []
cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
cu_seqlens: list = []
@@ -867,7 +863,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
llm_w = w // self.spatial_merge_size
(
- rotary_pos_emb_thw,
+ cos_thw,
+ sin_thw,
window_index_thw,
cu_seqlens_window_thw,
cu_seqlens_thw,
@@ -880,11 +877,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_window_seqlens_last = cu_seqlens_window_thw[-1]
cu_window_seqlens.append(cu_seqlens_window_thw)
- rotary_pos_emb.append(rotary_pos_emb_thw)
+ rotary_pos_emb_cos.append(cos_thw)
+ rotary_pos_emb_sin.append(sin_thw)
cu_seqlens.append(cu_seqlens_thw)
- rotary_pos_emb = torch.cat(rotary_pos_emb)
+ rotary_pos_emb_cos = torch.cat(rotary_pos_emb_cos)
+ rotary_pos_emb_sin = torch.cat(rotary_pos_emb_sin)
window_index = torch.cat(window_index)
# compute reverse indices
reverse_indices = self.invert_permutation(window_index)
@@ -903,7 +902,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
- rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True)
+ rotary_pos_emb_cos = rotary_pos_emb_cos.to(
+ device=self.device, non_blocking=True
+ )
+ rotary_pos_emb_sin = rotary_pos_emb_sin.to(
+ device=self.device, non_blocking=True
+ )
window_index = window_index.to(device=hidden_states.device, non_blocking=True)
reverse_indices = reverse_indices.to(
device=hidden_states.device, non_blocking=True
@@ -930,7 +934,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
seqlens=seqlens_now,
)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 5d21e249fc4cc..53df5972a8fe1 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -32,7 +32,7 @@ from typing import Annotated, Any, Literal, TypeAlias
import torch
import torch.nn as nn
import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
@@ -59,7 +59,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
+ apply_rotary_emb_torch,
dispatch_rotary_emb_function,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -275,47 +277,13 @@ class Qwen2VisionMLP(nn.Module):
return x
-def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
- if not interleaved:
- x1, x2 = x.chunk(2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
- else:
- x1, x2 = x[..., ::2], x[..., 1::2]
- return rearrange(
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
- )
-
-
-def apply_rotary_emb_torch(
- x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+def apply_rotary_pos_emb_vision(
+ t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
- """
- x: (batch_size, seqlen, nheads, headdim)
- cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
- """
- ro_dim = cos.shape[-1] * 2
- assert ro_dim <= x.shape[-1]
- cos = repeat(
- cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+ rotary_emb_function = dispatch_rotary_emb_function(
+ default=partial(apply_rotary_emb_torch, is_neox_style=True)
)
- sin = repeat(
- sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
- )
- return torch.cat(
- [
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
- x[..., ro_dim:],
- ],
- dim=-1,
- )
-
-
-def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
- rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
- t_ = t.float()
- cos = freqs.cos()
- sin = freqs.sin()
- output = rotary_emb_function(t_, cos, sin).type_as(t)
+ output = rotary_emb_function(t, cos, sin).type_as(t)
return output
@@ -412,7 +380,8 @@ class Qwen2VisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
@@ -424,11 +393,13 @@ class Qwen2VisionAttention(nn.Module):
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
- if rotary_pos_emb is not None:
- # [2 * b, s, heads, head_dim]
- qk_concat = torch.cat([q, k], dim=0)
- qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
- q, k = torch.chunk(qk_rotated, 2, dim=0)
+
+ # [2 * b, s, heads, head_dim]
+ qk_concat = torch.cat([q, k], dim=0)
+ qk_rotated = apply_rotary_pos_emb_vision(
+ qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
+ )
+ q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
@@ -534,14 +505,16 @@ class Qwen2VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -628,40 +601,6 @@ class Qwen2VisionPatchMerger(nn.Module):
return out
-class Qwen2VisionRotaryEmbedding(nn.Module):
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
- super().__init__()
- self.dim = dim
- self.theta = theta
- inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self._seq_len_cached = 0
- self._freqs_cached = None
-
- def update_freqs_cache(self, seqlen: int) -> None:
- if seqlen > self._seq_len_cached:
- seqlen *= 2
- self._seq_len_cached = seqlen
- self.inv_freq = 1.0 / (
- self.theta
- ** (
- torch.arange(
- 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
- )
- / self.dim
- )
- )
- seq = torch.arange(
- seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
- )
- freqs = torch.outer(seq, self.inv_freq)
- self._freqs_cached = freqs
-
- def forward(self, seqlen: int) -> torch.Tensor:
- self.update_freqs_cache(seqlen)
- return self._freqs_cached[:seqlen]
-
-
class Qwen2VisionTransformer(nn.Module):
def __init__(
self,
@@ -700,7 +639,13 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads
- self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
+ self.rotary_pos_emb = get_rope(
+ head_size=head_dim,
+ rotary_dim=head_dim // 2,
+ max_position=8192,
+ base=10000.0,
+ is_neox_style=True,
+ )
self.blocks = nn.ModuleList(
[
@@ -744,7 +689,9 @@ class Qwen2VisionTransformer(nn.Module):
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
- def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
+ def rot_pos_emb(
+ self, grid_thw: list[list[int]]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
pos_ids = []
max_grid_size = 0
for t, h, w in grid_thw:
@@ -773,9 +720,18 @@ class Qwen2VisionTransformer(nn.Module):
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
max_grid_size = max(max_grid_size, h, w)
pos_ids = torch.cat(pos_ids, dim=0)
- rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
- rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
- return rotary_pos_emb
+
+ # Use pre-computed cos_sin_cache from RotaryEmbedding
+ cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+ cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
+ cos_w = cos[pos_ids[:, 1]]
+ sin_h = sin[pos_ids[:, 0]]
+ sin_w = sin[pos_ids[:, 1]]
+
+ cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+ sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+ return cos_combined, sin_combined
def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
@@ -806,7 +762,7 @@ class Qwen2VisionTransformer(nn.Module):
grid_thw_list = grid_thw.tolist()
# compute position embedding
- rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
+ rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
@@ -824,7 +780,8 @@ class Qwen2VisionTransformer(nn.Module):
x = blk(
x,
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 40b80ce2387c8..8274b92138f78 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo
@@ -90,7 +91,6 @@ from .qwen2_5_omni_thinker import (
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
- Qwen2_5_VisionRotaryEmbedding,
Qwen2_5_VLProcessingInfo,
)
from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
@@ -221,14 +221,16 @@ class Qwen3_VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -332,7 +334,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
- self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+ self.rotary_pos_emb = get_rope(
+ head_size=head_dim,
+ rotary_dim=head_dim // 2,
+ max_position=8192,
+ base=10000.0,
+ is_neox_style=True,
+ )
self.blocks = nn.ModuleList(
[
@@ -416,9 +424,19 @@ class Qwen3Omni_VisionTransformer(nn.Module):
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
- rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
- rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
- return rotary_pos_emb
+
+ # Use pre-computed cos_sin_cache from RotaryEmbedding
+ cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+ cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
+ cos_w = cos[pos_ids[:, 1]]
+ sin_h = sin[pos_ids[:, 0]]
+ sin_w = sin[pos_ids[:, 1]]
+
+ cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+ sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+
+ return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
@@ -508,7 +526,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
if self.apply_vit_abs_pos_embed:
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds
- rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -519,7 +537,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
hidden_states = hidden_states.unsqueeze(1)
- rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+ rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device)
+ rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device)
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
hidden_states_list = []
@@ -529,7 +548,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index f1c020ab5813c..99a4007ef7f23 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -24,8 +24,8 @@
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
-from collections.abc import Callable, Iterable, Mapping, Sequence
-from functools import partial
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from functools import lru_cache, partial
from itertools import islice
from typing import Any
@@ -63,6 +63,7 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -95,7 +96,6 @@ from .interfaces import (
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
- Qwen2_5_VisionRotaryEmbedding,
Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs,
Qwen2_5_VLImagePixelInputs,
@@ -232,14 +232,16 @@ class Qwen3_VisionBlock(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
+ rotary_pos_emb_cos: torch.Tensor,
+ rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -339,7 +341,13 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
- self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+ self.rotary_pos_emb = get_rope(
+ head_size=head_dim,
+ rotary_dim=head_dim // 2,
+ max_position=8192,
+ base=10000.0,
+ is_neox_style=True,
+ )
self.merger = Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
@@ -416,34 +424,55 @@ class Qwen3_VisionTransformer(nn.Module):
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
- def rot_pos_emb(self, grid_thw: list[list[int]]):
- pos_ids = []
- max_grid_size = max(max(h, w) for _, h, w in grid_thw)
- for t, h, w in grid_thw:
- hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
- hpos_ids = hpos_ids.reshape(
- h // self.spatial_merge_size,
- self.spatial_merge_size,
- w // self.spatial_merge_size,
- self.spatial_merge_size,
- )
- hpos_ids = hpos_ids.permute(0, 2, 1, 3)
- hpos_ids = hpos_ids.flatten()
+ @staticmethod
+ @lru_cache(maxsize=1024)
+ def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+ hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+ h_div = h // spatial_merge_size
+ w_div = w // spatial_merge_size
+ hpos_ids = hpos_ids.reshape(
+ h_div,
+ spatial_merge_size,
+ w_div,
+ spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
- wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
- wpos_ids = wpos_ids.reshape(
- h // self.spatial_merge_size,
- self.spatial_merge_size,
- w // self.spatial_merge_size,
- self.spatial_merge_size,
- )
- wpos_ids = wpos_ids.permute(0, 2, 1, 3)
- wpos_ids = wpos_ids.flatten()
- pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+ wpos_ids = wpos_ids.reshape(
+ h_div,
+ spatial_merge_size,
+ w_div,
+ spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+
+ return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
+ def rot_pos_emb(self, grid_thw: list[list[int]]):
+ max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+ pos_ids = [
+ self.rot_pos_ids(h, w, self.spatial_merge_size)
+ if t == 1
+ else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+ for t, h, w in grid_thw
+ ]
pos_ids = torch.cat(pos_ids, dim=0)
- rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
- rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
- return rotary_pos_emb
+
+ # Use pre-computed cos_sin_cache from RotaryEmbedding
+ cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+ cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
+ cos_w = cos[pos_ids[:, 1]]
+ sin_h = sin[pos_ids[:, 0]]
+ sin_w = sin[pos_ids[:, 1]]
+
+ cos_combined = torch.cat([cos_h, cos_w], dim=-1)
+ sin_combined = torch.cat([sin_h, sin_w], dim=-1)
+
+ return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
@@ -536,8 +565,13 @@ class Qwen3_VisionTransformer(nn.Module):
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
- rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
- rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
+ rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
+ rotary_pos_emb_cos = rotary_pos_emb_cos.to(
+ hidden_states.device, non_blocking=True
+ )
+ rotary_pos_emb_sin = rotary_pos_emb_sin.to(
+ hidden_states.device, non_blocking=True
+ )
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -553,7 +587,8 @@ class Qwen3_VisionTransformer(nn.Module):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
- rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -1412,72 +1447,47 @@ class Qwen3VLForConditionalGeneration(
)
return mm_input_by_modality
+ def iter_mm_grid_hw(
+ self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
+ ) -> Iterator[tuple[int, int, int]]:
+ video_token_id = self.config.video_token_id
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
+ offset = mm_feature.mm_position.offset
+ if mm_feature.modality == "image":
+ t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
+ assert t == 1, f"Image must have 1 frame, got {t}"
+ yield offset, h // spatial_merge_size, w // spatial_merge_size
+ elif mm_feature.modality == "video":
+ t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
+ llm_grid_h = h // spatial_merge_size
+ llm_grid_w = w // spatial_merge_size
+ for _ in range(t):
+ offset = input_tokens.index(video_token_id, offset)
+ yield offset, llm_grid_h, llm_grid_w
+ offset += llm_grid_h * llm_grid_w
+ else:
+ raise ValueError(f"Unsupported modality: {mm_feature.modality}")
+
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
- kwargs = MultiModalFeatureSpec.gather_kwargs(
- mm_features,
- {"image_grid_thw", "video_grid_thw"},
- )
- image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
- video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
-
- video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)]
-
- hf_config = self.config
- image_token_id = hf_config.image_token_id
- video_token_id = hf_config.video_token_id
- vision_start_token_id = hf_config.vision_start_token_id
- spatial_merge_size = hf_config.vision_config.spatial_merge_size
-
- input_tokens_array = np.array(input_tokens)
- vision_start_mask = input_tokens_array == vision_start_token_id
- vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
- image_nums = np.count_nonzero(vision_tokens == image_token_id)
- video_nums = np.count_nonzero(vision_tokens == video_token_id)
- llm_pos_ids_list: list = []
-
+ llm_pos_ids_list = []
st = 0
- remain_images, remain_videos = image_nums, video_nums
-
- image_index, video_index = 0, 0
- for _ in range(image_nums + video_nums):
- if image_token_id in input_tokens and remain_images > 0:
- ed_image = input_tokens.index(image_token_id, st)
- else:
- ed_image = len(input_tokens) + 1
- if video_token_id in input_tokens and remain_videos > 0:
- ed_video = input_tokens.index(video_token_id, st)
- else:
- ed_video = len(input_tokens) + 1
- if ed_image < ed_video:
- t, h, w = image_grid_thw[image_index]
- image_index += 1
- remain_images -= 1
- ed = ed_image
- else:
- t, h, w = video_grid_thw[video_index]
- video_index += 1
- remain_videos -= 1
- ed = ed_video
-
- llm_grid_t, llm_grid_h, llm_grid_w = (
- t,
- h // spatial_merge_size,
- w // spatial_merge_size,
- )
- text_len = ed - st
-
+ for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
+ input_tokens, mm_features
+ ):
+ text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
- grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
- llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
- st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+ grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
+ llm_pos_ids_list.append(grid_indices + text_len + st_idx)
+ st = offset + llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index 6a259cade9cf1..4906cf441f6fb 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -25,6 +25,7 @@ from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
@@ -333,7 +334,7 @@ class VisionTransformer(nn.Module):
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
- self.conv1 = nn.Conv2d(
+ self.conv1 = Conv2dLayer(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index c50be12883897..8211321c39537 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -56,6 +56,7 @@ logger = init_logger(__name__)
_TEXT_GENERATION_MODELS = {
# [Decoder-only]
+ "AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
"ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
@@ -597,7 +598,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
mi_dict = json.load(file)
except FileNotFoundError:
logger.debug(
- ("Cached model info file for class %s.%s not found"),
+ "Cached model info file for class %s.%s not found",
self.module_name,
self.class_name,
)
@@ -605,7 +606,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
if mi_dict["hash"] != module_hash:
logger.debug(
- ("Cached model info file for class %s.%s is stale"),
+ "Cached model info file for class %s.%s is stale",
self.module_name,
self.class_name,
)
@@ -615,7 +616,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
return _ModelInfo(**mi_dict["modelinfo"])
except Exception:
logger.debug(
- ("Cached model info for class %s.%s error. "),
+ "Cached model info for class %s.%s error. ",
self.module_name,
self.class_name,
)
@@ -650,14 +651,14 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
mi = self._load_modelinfo_from_cache(module_hash)
if mi is not None:
logger.debug(
- ("Loaded model info for class %s.%s from cache"),
+ "Loaded model info for class %s.%s from cache",
self.module_name,
self.class_name,
)
return mi
else:
logger.debug(
- ("Cache model info for class %s.%s miss. Loading model instead."),
+ "Cache model info for class %s.%s miss. Loading model instead.",
self.module_name,
self.class_name,
)
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index b175dd60cf650..ce5847bf79a5e 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -286,7 +287,7 @@ class SiglipVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
@@ -827,6 +828,7 @@ class SiglipVisionModel(nn.Module):
) -> None:
super().__init__()
+ self.quant_config = quant_config
self.vision_model = SiglipVisionTransformer(
config,
quant_config,
@@ -911,12 +913,38 @@ class SiglipVisionModel(nn.Module):
break
else:
param = params_dict[name]
+ param = maybe_swap_ffn_param(
+ name, param, loaded_weight, params_dict, self.quant_config
+ )
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
+def maybe_swap_ffn_param(
+ name: str,
+ param: torch.Tensor,
+ loaded_weight: torch.Tensor,
+ params_dict: dict[str, torch.Tensor],
+    quant_config: QuantizationConfig | None,
+) -> torch.Tensor:
+ if not (quant_config and quant_config.get_name() == "gguf") or ".fc" not in name:
+ return param
+ # Some GGUF models have fc1 and fc2 weights swapped
+ tp_size = get_tensor_model_parallel_world_size()
+ output_dim = getattr(param, "output_dim", 0)
+ output_size = param.size(output_dim) * tp_size
+ weight_out_size = loaded_weight.size(output_dim)
+ if ".fc1." in name and output_size != weight_out_size:
+ new_name = name.replace(".fc1.", ".fc2.")
+ param = params_dict[new_name]
+ elif ".fc2." in name and output_size != weight_out_size:
+ new_name = name.replace(".fc2.", ".fc1.")
+ param = params_dict[new_name]
+ return param
+
+
# Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200
class SiglipTextEmbeddings(nn.Module):
def __init__(self, config: SiglipTextConfig):
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index 29dd164ad37fd..46f5e67d659ef 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
@@ -67,7 +68,7 @@ class Siglip2VisionEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
else:
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
@@ -99,7 +100,7 @@ class Siglip2VisionEmbeddings(nn.Module):
target_dtype = self.patch_embedding.weight.dtype
if isinstance(self.patch_embedding, LinearBase):
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
- elif isinstance(self.patch_embedding, nn.Conv2d):
+ elif isinstance(self.patch_embedding, Conv2dLayer):
pixel_values = pixel_values.view(
-1,
self.config.num_channels * self.config.temporal_patch_size,
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index 5d16be1eb3128..1c60cb4148121 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -20,6 +20,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -667,7 +668,7 @@ class Step3VisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
- self.patch_embedding = nn.Conv2d(
+ self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
@@ -950,13 +951,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
- self.vit_downsampler = nn.Conv2d(
+ self.vit_downsampler = Conv2dLayer(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,
kernel_size=2,
stride=config.understand_projector_stride,
)
- self.vit_downsampler2 = nn.Conv2d(
+ self.vit_downsampler2 = Conv2dLayer(
config.vision_config.output_hidden_size,
config.vision_config.output_hidden_size * 2,
kernel_size=3,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index e5663c8a057ac..ccefd7e66697f 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -18,7 +18,14 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig,
+)
+from vllm.model_executor.model_loader.online_quantization import (
+ support_quantized_model_reload_from_hp_weights,
+)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
@@ -312,6 +319,7 @@ class AutoWeightsLoader:
)
raise ValueError(msg)
+ @support_quantized_model_reload_from_hp_weights
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
@@ -714,6 +722,30 @@ def maybe_prefix(prefix: str, name: str) -> str:
return name if not prefix else f"{prefix}.{name}"
+def get_draft_quant_config(
+ vllm_config: VllmConfig,
+) -> QuantizationConfig | None:
+ """Get quantization config for Draft models.
+
+ Draft models should use their own quantization config instead of the verifier/target
+ model's config. This helper retrieves the draft model's quantization config.
+
+ Args:
+ vllm_config: The vLLM configuration object.
+
+ Returns:
+        The draft model's quantization config if available, None otherwise.
+ """
+ draft_model_config = vllm_config.speculative_config.draft_model_config
+ draft_load_config = vllm_config.load_config
+
+ return (
+ VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
+ if draft_model_config
+ else None
+ )
+
+
def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
"""
Extract the layer index from the module name.
@@ -825,3 +857,25 @@ direct_register_custom_op(
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
+
+
+def process_eagle_weight(
+ model: nn.Module,
+ name: str,
+) -> None:
+ """
+ Update EAGLE model flags based on loaded weight name.
+ This should be called during weight loading to detect if a model
+ has its own lm_head or embed_tokens weight.
+ Args:
+ model: The model instance (must support EAGLE)
+ name: The name of the weight to process
+ """
+ if not supports_any_eagle(model):
+ return
+
+ # To prevent overriding with target model's layers
+ if "lm_head" in name:
+ model.has_own_lm_head = True
+ if "embed_tokens" in name:
+ model.has_own_embed_tokens = True
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index 64e6979c8fcfb..729a9655d0879 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -567,11 +567,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# Process through Mamba mixer
- output = torch.empty_like(hidden_states)
- self.mamba(
- hidden_states,
- output,
- )
+ output = self.mamba(hidden_states)
# residual connection after mamba
hidden_states = residual + output
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index 2fa3f6ebcc114..810f29072a0fe 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -359,8 +359,9 @@ class MultiModalDataParser:
)
self.video_needs_metadata = video_needs_metadata
- def _is_embeddings(
- self, data: object
+ @classmethod
+ def is_embeddings(
+ cls, data: object
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
if isinstance(data, torch.Tensor):
return data.ndim == 3
@@ -420,7 +421,7 @@ class MultiModalDataParser:
):
return None
- if self._is_embeddings(data):
+ if self.is_embeddings(data):
return AudioEmbeddingItems(data)
data_items: list[AudioItem]
@@ -458,7 +459,7 @@ class MultiModalDataParser:
if self._is_empty(data):
return None
- if self._is_embeddings(data):
+ if self.is_embeddings(data):
return ImageEmbeddingItems(data)
if (
@@ -484,7 +485,7 @@ class MultiModalDataParser:
if self._is_empty(data):
return None
- if self._is_embeddings(data):
+ if self.is_embeddings(data):
return VideoEmbeddingItems(data)
data_items: list[VideoItem]
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 1da34629472c7..ed655912d3964 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -339,7 +339,7 @@ class CpuPlatform(Platform):
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
- vllm_config.scheduler_config.max_model_len,
+ vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 788f9d69c357a..bb116792fed54 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -185,6 +185,9 @@ class RocmPlatform(Platform):
"petit_nvfp4",
"torchao",
]
+ # bitsandbytes not supported on gfx9 (warp size 64 limitation)
+ if not on_gfx9():
+ supported_quantization += ["bitsandbytes"]
@classmethod
def get_vit_attn_backend(
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index c1218801bc077..944344a229578 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -191,7 +191,7 @@ class TpuPlatform(Platform):
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
- vllm_config.scheduler_config.max_model_len,
+ vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index ad4beb28bdae0..65516827a16da 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -185,7 +185,7 @@ class XPUPlatform(Platform):
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
- vllm_config.scheduler_config.max_model_len,
+ vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 72a8320cc1bf8..5c3dfa8ac9cbc 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -57,6 +57,7 @@ class PoolingParams(
## Internal use only
task: PoolingTask | None = None
requires_token_ids: bool = False
+    skip_reading_prefix_cache: bool | None = None
extra_kwargs: dict[str, Any] | None = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@@ -93,6 +94,8 @@ class PoolingParams(
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
+ if self.skip_reading_prefix_cache is None:
+ self.skip_reading_prefix_cache = True
return
# NOTE: Task validation needs to done against the model instance,
@@ -122,6 +125,15 @@ class PoolingParams(
if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k))
+ if self.skip_reading_prefix_cache is None:
+ # If prefix caching is enabled,
+        # the output of all pooling may be less than n_prompt_tokens,
+ # we need to skip reading cache at this request.
+ if self.task in ["token_embed", "token_classify"]:
+ self.skip_reading_prefix_cache = True
+ else:
+ self.skip_reading_prefix_cache = False
+
self._verify_step_pooling(pooler_config, valid_parameters)
def _verify_step_pooling(
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index dd820840410ed..0fb1d67687c82 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -204,6 +204,12 @@ class SamplingParams(
prompt_logprobs: int | None = None
"""Number of log probabilities to return per prompt token.
When set to -1, return all `vocab_size` log probabilities."""
+ flat_logprobs: bool = False
+ """Whether to return logprobs in flatten format (i.e. FlatLogprob)
+ for better performance.
+ NOTE: GC costs of FlatLogprobs is significantly smaller than
+ list[dict[int, Logprob]]. After enabled, PromptLogprobs and
+ SampleLogprobs would populated as FlatLogprobs."""
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
@@ -254,6 +260,8 @@ class SamplingParams(
generated token can complete the sequence."""
_bad_words_token_ids: list[list[int]] | None = None
+    skip_reading_prefix_cache: bool | None = None
+
@staticmethod
def from_optional(
n: int | None = 1,
@@ -414,6 +422,12 @@ class SamplingParams(
self.structured_outputs = self.guided_decoding
self.guided_decoding = None
+ if self.skip_reading_prefix_cache is None:
+ # If prefix caching is enabled,
+        # the output of prompt logprobs may be less than n_prompt_tokens,
+ # we need to skip reading cache at this request.
+ self.skip_reading_prefix_cache = self.prompt_logprobs is not None
+
def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 6bcc94ad5c625..6d20ca9aac225 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -60,12 +60,17 @@ class IntermediateTensors:
tensors: dict[str, torch.Tensor]
kv_connector_output: KVConnectorOutput | None
- def __init__(self, tensors):
+ def __init__(
+ self,
+ tensors: dict[str, torch.Tensor],
+ kv_connector_output: KVConnectorOutput | None = None,
+ ) -> None:
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
+ self.kv_connector_output = kv_connector_output
def __getitem__(self, key: str | slice):
if isinstance(key, str):
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index b7418cfb7cc75..ac4a71648cec8 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -77,6 +77,7 @@ class LazyConfigDict(dict):
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
+ afmoe="AfmoeConfig",
chatglm="ChatGLMConfig",
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32=DeepseekV3Config,
@@ -476,6 +477,17 @@ def is_interleaved(config: PretrainedConfig) -> bool:
return False
+def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
+ """Detect if model uses custom attention mask generation for multimodal.
+
+ Some multimodal models require custom attention masks that enable
+ bidirectional attention between image tokens while maintaining causal
+ attention for text tokens. Currently applies to Gemma3 multimodal models.
+ """
+ architectures = getattr(config, "architectures", [])
+ return "Gemma3ForConditionalGeneration" in architectures
+
+
def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
"""
Update kwargs for AutoConfig initialization based on model_type
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index ac612b255143c..dcae05a15fec3 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -7,6 +7,7 @@ Model configs may be defined in this directory for the following reasons:
- There is a need to override the existing config to support vLLM.
"""
+from vllm.transformers_utils.configs.afmoe import AfmoeConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig
@@ -40,6 +41,7 @@ from vllm.transformers_utils.configs.step3_vl import (
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
+ "AfmoeConfig",
"ChatGLMConfig",
"DeepseekVLV2Config",
"DotsOCRConfig",
diff --git a/vllm/transformers_utils/configs/afmoe.py b/vllm/transformers_utils/configs/afmoe.py
new file mode 100644
index 0000000000000..9b634fd037a33
--- /dev/null
+++ b/vllm/transformers_utils/configs/afmoe.py
@@ -0,0 +1,84 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class AfmoeConfig(PretrainedConfig):
+ model_type = "afmoe"
+
+ def __init__(
+ self,
+ vocab_size: int = 200_192,
+ hidden_size: int = 2048,
+ intermediate_size: int = 6144,
+ moe_intermediate_size: int = 1408,
+ num_hidden_layers: int = 32,
+ num_dense_layers: int = 1,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int | None = None,
+ head_dim: int = 128,
+ hidden_act: str = "silu",
+ max_position_embeddings: int = 131072,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-5,
+ use_cache: bool = True,
+ tie_word_embeddings: bool = False,
+ rope_theta: float = 10000.0,
+ rope_scaling: dict | None = None,
+ num_experts: int = 64,
+ num_experts_per_tok: int = 6,
+ num_shared_experts: int = 2,
+ num_expert_groups: int = 1,
+ num_limited_groups: int = 1,
+ score_func: str = "sigmoid",
+ route_norm: bool = True,
+ route_scale: float = 1.0,
+ global_attn_every_n_layers: int = 4,
+ sliding_window: int = 2048,
+ layer_types: list[str] | None = None,
+ attention_dropout: float = 0.0,
+ mup_enabled: bool = False,
+ n_group: int = 1,
+ topk_group: int = 1,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_dense_layers = num_dense_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = head_dim
+ self.hidden_act = hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_experts = num_experts
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_shared_experts = num_shared_experts
+ self.num_expert_groups = num_expert_groups
+ self.num_limited_groups = num_limited_groups
+ self.score_func = score_func
+ self.route_norm = route_norm
+ self.route_scale = route_scale
+
+ self.global_attn_every_n_layers = global_attn_every_n_layers
+ self.sliding_window = sliding_window
+ self.layer_types = layer_types
+ self.attention_dropout = attention_dropout
+
+ self.mup_enabled = mup_enabled
+ self.n_group = n_group
+ self.topk_group = topk_group
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = ["AfmoeConfig"]
diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py
new file mode 100644
index 0000000000000..2bf59c91a3bb1
--- /dev/null
+++ b/vllm/transformers_utils/gguf_utils.py
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""GGUF utility functions."""
+
+from pathlib import Path
+
+import gguf
+from gguf.constants import Keys, VisionProjectorType
+from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def detect_gguf_multimodal(model: str) -> Path | None:
+ """Check if GGUF model has multimodal projector file.
+
+ Args:
+ model: Model path string
+
+ Returns:
+ Path to mmproj file if found, None otherwise
+ """
+ if not model.endswith(".gguf"):
+ return None
+
+ try:
+ model_path = Path(model)
+ if not model_path.is_file():
+ return None
+
+ model_dir = model_path.parent
+ mmproj_patterns = ["mmproj.gguf", "mmproj-*.gguf", "*mmproj*.gguf"]
+ for pattern in mmproj_patterns:
+ mmproj_files = list(model_dir.glob(pattern))
+ if mmproj_files:
+ return mmproj_files[0]
+ return None
+ except Exception:
+ return None
+
+
+def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | None":
+ """Extract vision config parameters from mmproj.gguf metadata.
+
+ Reads vision encoder configuration from GGUF metadata fields using
+ standardized GGUF constants. Automatically detects the projector type
+ (e.g., gemma3, llama4) and applies model-specific parameters accordingly.
+
+ The function extracts standard CLIP vision parameters from GGUF metadata
+ and applies projector-type-specific customizations. For unknown projector
+ types, it uses safe defaults from SiglipVisionConfig.
+
+ Args:
+ mmproj_path: Path to mmproj.gguf file (str or Path)
+
+ Returns:
+ SiglipVisionConfig if extraction succeeds, None if any required
+ field is missing from the GGUF metadata
+
+ Raises:
+ Exception: Exceptions from GGUF reading (file not found, corrupted
+ file, etc.) propagate directly from gguf.GGUFReader
+ """
+ reader = gguf.GGUFReader(str(mmproj_path))
+
+ # Detect projector type to apply model-specific parameters
+ projector_type = None
+ projector_type_field = reader.get_field(Keys.Clip.PROJECTOR_TYPE)
+ if projector_type_field:
+ try:
+ projector_type = bytes(projector_type_field.parts[-1]).decode("utf-8")
+ except (AttributeError, UnicodeDecodeError) as e:
+ logger.warning("Failed to decode projector type from GGUF: %s", e)
+
+ # Map GGUF field constants to SiglipVisionConfig parameters.
+ # Uses official GGUF constants from gguf-py for standardization.
+ # Format: {gguf_constant: (param_name, dtype)}
+ VISION_CONFIG_FIELDS = {
+ Keys.ClipVision.EMBEDDING_LENGTH: ("hidden_size", int),
+ Keys.ClipVision.FEED_FORWARD_LENGTH: ("intermediate_size", int),
+ Keys.ClipVision.BLOCK_COUNT: ("num_hidden_layers", int),
+ Keys.ClipVision.Attention.HEAD_COUNT: ("num_attention_heads", int),
+ Keys.ClipVision.IMAGE_SIZE: ("image_size", int),
+ Keys.ClipVision.PATCH_SIZE: ("patch_size", int),
+ Keys.ClipVision.Attention.LAYERNORM_EPS: ("layer_norm_eps", float),
+ }
+
+ # Extract and validate all required fields
+ config_params = {}
+ for gguf_key, (param_name, dtype) in VISION_CONFIG_FIELDS.items():
+ field = reader.get_field(gguf_key)
+ if field is None:
+ logger.warning(
+ "Missing required vision config field '%s' in mmproj.gguf",
+ gguf_key,
+ )
+ return None
+ # Extract scalar value from GGUF field and convert to target type
+ config_params[param_name] = dtype(field.parts[-1])
+
+ # Apply model-specific parameters based on projector type
+ if projector_type == VisionProjectorType.GEMMA3:
+ # Gemma3 doesn't use the vision pooling head (multihead attention)
+ # This is a vLLM-specific parameter used in SiglipVisionTransformer
+ config_params["vision_use_head"] = False
+ logger.info("Detected Gemma3 projector, disabling vision pooling head")
+ # Add other projector-type-specific customizations here as needed
+ # elif projector_type == VisionProjectorType.LLAMA4:
+ # config_params["vision_use_head"] = ...
+
+ # Create config with extracted parameters
+ # Note: num_channels and attention_dropout use SiglipVisionConfig defaults
+ # (3 and 0.0 respectively) which are correct for all models
+ config = SiglipVisionConfig(**config_params)
+
+ if projector_type:
+ logger.info(
+ "Extracted vision config from mmproj.gguf (projector_type: %s)",
+ projector_type,
+ )
+ else:
+ logger.info("Extracted vision config from mmproj.gguf metadata")
+
+ return config
+
+
+def maybe_patch_hf_config_from_gguf(
+ model: str,
+ hf_config: PretrainedConfig,
+) -> PretrainedConfig:
+ """Patch HF config for GGUF models.
+
+ Applies GGUF-specific patches to HuggingFace config:
+ 1. For multimodal models: patches architecture and vision config
+ 2. For all GGUF models: overrides vocab_size from embedding tensor
+
+ This ensures compatibility with GGUF models that have extended
+ vocabularies (e.g., Unsloth) where the GGUF file contains more
+ tokens than the HuggingFace tokenizer config specifies.
+
+ Args:
+ model: Model path string
+ hf_config: HuggingFace config to patch in-place
+
+ Returns:
+ Updated HuggingFace config
+ """
+ # Patch multimodal config if mmproj.gguf exists
+ mmproj_path = detect_gguf_multimodal(model)
+ if mmproj_path is not None:
+ vision_config = extract_vision_config_from_gguf(str(mmproj_path))
+
+ # Create HF config for Gemma3 multimodal
+ text_config = hf_config.get_text_config()
+ is_gemma3 = hf_config.model_type in ("gemma3", "gemma3_text")
+ if vision_config is not None and is_gemma3:
+ new_hf_config = Gemma3Config.from_text_vision_configs(
+ text_config=text_config,
+ vision_config=vision_config,
+ architectures=["Gemma3ForConditionalGeneration"],
+ )
+ hf_config = new_hf_config
+
+ return hf_config
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py
index b3469c1b18f2d..8deacb5b07913 100644
--- a/vllm/transformers_utils/processor.py
+++ b/vllm/transformers_utils/processor.py
@@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
-from vllm.transformers_utils.utils import convert_model_repo_to_path
+from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
@@ -236,9 +236,20 @@ def cached_processor_from_config(
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any,
) -> _P:
+ if check_gguf_file(model_config.model):
+ assert not check_gguf_file(model_config.tokenizer), (
+ "For multimodal GGUF models, the original tokenizer "
+ "should be used to correctly load processor."
+ )
+ model = model_config.tokenizer
+ revision = model_config.tokenizer_revision
+ else:
+ model = model_config.model
+ revision = model_config.revision
+
return cached_get_processor_without_dynamic_kwargs(
- model_config.model,
- revision=model_config.revision,
+ model,
+ revision=revision,
trust_remote_code=model_config.trust_remote_code,
processor_cls=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, processor_cls, **kwargs),
@@ -339,9 +350,19 @@ def cached_image_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
+ if check_gguf_file(model_config.model):
+ assert not check_gguf_file(model_config.tokenizer), (
+ "For multimodal GGUF models, the original tokenizer "
+ "should be used to correctly load image processor."
+ )
+ model = model_config.tokenizer
+ revision = model_config.tokenizer_revision
+ else:
+ model = model_config.model
+ revision = model_config.revision
return cached_get_image_processor(
- model_config.model,
- revision=model_config.revision,
+ model,
+ revision=revision,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs),
)
diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py
index 1ae42ba622dc4..901a64d9d2633 100644
--- a/vllm/transformers_utils/utils.py
+++ b/vllm/transformers_utils/utils.py
@@ -27,6 +27,7 @@ def is_cloud_storage(model_or_path: str) -> bool:
return is_s3(model_or_path) or is_gcs(model_or_path)
+@cache
def check_gguf_file(model: str | PathLike) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 79e5a4c302594..1209d64901bf5 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -319,14 +319,12 @@ def use_trtllm_attention(
# Environment variable not set - use auto-detection
if is_prefill:
# Prefill auto-detection
- use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
+ use_trtllm = kv_cache_dtype == "auto"
if use_trtllm:
logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
else:
# Decode auto-detection
- use_trtllm = (
- num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
- )
+ use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
if use_trtllm:
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm
diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py
index f01d2c7a6a33d..ff0f0350fd941 100644
--- a/vllm/utils/import_utils.py
+++ b/vllm/utils/import_utils.py
@@ -18,6 +18,10 @@ from typing import Any
import regex as re
from typing_extensions import Never
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
@@ -62,6 +66,35 @@ def import_pynvml():
return pynvml
+@cache
+def import_triton_kernels():
+ """
+ For convenience, prioritize triton_kernels that is available in
+ `site-packages`. Use `vllm.third_party.triton_kernels` as a fall-back.
+ """
+ if _has_module("triton_kernels"):
+ import triton_kernels
+
+ logger.debug_once(
+ f"Loading module triton_kernels from {triton_kernels.__file__}.",
+ scope="local",
+ )
+ elif _has_module("vllm.third_party.triton_kernels"):
+ import vllm.third_party.triton_kernels as triton_kernels
+
+ logger.debug_once(
+ f"Loading module triton_kernels from {triton_kernels.__file__}.",
+ scope="local",
+ )
+ sys.modules["triton_kernels"] = triton_kernels
+ else:
+ logger.info_once(
+ "triton_kernels unavailable in this build. "
+ "Please consider installing triton_kernels from "
+ "https://github.com/triton-lang/triton/tree/main/python/triton_kernels"
+ )
+
+
def import_from_path(module_name: str, file_path: str | os.PathLike):
"""
Import a Python file according to its file path.
@@ -397,7 +430,12 @@ def has_deep_gemm() -> bool:
def has_triton_kernels() -> bool:
"""Whether the optional `triton_kernels` package is available."""
- return _has_module("triton_kernels")
+ is_available = _has_module("triton_kernels") or _has_module(
+ "vllm.third_party.triton_kernels"
+ )
+ if is_available:
+ import_triton_kernels()
+ return is_available
def has_tilelang() -> bool:
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index ad454daa582eb..ea611848b0e81 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -729,7 +729,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
- min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
+ min_seqlen_q=1,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
@@ -759,7 +759,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
- min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
+ min_seqlen_q=1,
block_table=attn_metadata.block_table[
num_decodes : num_decodes + num_extends
],
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 63a1ff06e4049..7f405fc248ac2 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -185,12 +185,11 @@ class KVCacheManager:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
- # Prefix caching is disabled or
- # When the request requires prompt logprobs, we skip prefix caching.
- if not self.enable_caching or (
- request.sampling_params is not None
- and request.sampling_params.prompt_logprobs is not None
- ):
+ # We skip finding the prefix cache hit when prefix caching is
+ # disabled or the request is marked as skipping kv cache read
+ # (which happens when the request requires prompt logprobs
+ # or calls a pooling model with all pooling).
+ if not self.enable_caching or request.skip_reading_prefix_cache:
return self.empty_kv_cache_blocks, 0
# NOTE: When all tokens hit the cache, we must recompute the last token
diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py
index 0ad994c360b01..3214f65a09728 100644
--- a/vllm/v1/core/sched/async_scheduler.py
+++ b/vllm/v1/core/sched/async_scheduler.py
@@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler):
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
+ spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
+ cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if (
request.num_computed_tokens
- == request.num_tokens + request.num_output_placeholders
+ == request.num_tokens
+ + request.num_output_placeholders
+ + cur_num_spec_tokens
):
- # The request will generate a new token in this scheduling step.
- # TODO(woosuk): Support speculative decoding.
- request.num_output_placeholders += 1
+ # The request will generate a new token plus num_spec_tokens
+ # in this scheduling step.
+ request.num_output_placeholders += 1 + cur_num_spec_tokens
+ # Add placeholders for the new tokens in spec_token_ids.
+            # We will update the actual spec token ids in the worker process.
+ request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 041393d097b9a..e3260b3dae797 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
- self.max_model_len = self.scheduler_config.max_model_len
+ self.max_model_len = vllm_config.model_config.max_model_len
self.enable_kv_cache_events = (
self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events
@@ -353,7 +353,10 @@ class Scheduler(SchedulerInterface):
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (
- num_new_tokens + request.num_computed_tokens - request.num_tokens
+ num_new_tokens
+ + request.num_computed_tokens
+ - request.num_tokens
+ - request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
@@ -474,9 +477,9 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
)
- # KVTransfer: WAITING reqs have num_computed_tokens > 0
- # after async KV recvs are completed.
else:
+ # KVTransfer: WAITING reqs have num_computed_tokens > 0
+ # after async KV recvs are completed.
new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
@@ -485,12 +488,12 @@ class Scheduler(SchedulerInterface):
external_load_encoder_input = []
new_encoder_compute_budget = encoder_compute_budget
- # KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
+ # KVTransfer: loading remote KV, do not allocate for new work.
assert num_external_computed_tokens > 0
num_new_tokens = 0
- # Number of tokens to be scheduled.
else:
+ # Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
@@ -783,9 +786,7 @@ class Scheduler(SchedulerInterface):
assert not scheduled_in_prev_step
resumed_req_ids.add(req_id)
if not scheduled_in_prev_step:
- all_token_ids[req_id] = req.all_token_ids[
- : req.num_computed_tokens + num_tokens
- ]
+ all_token_ids[req_id] = req.all_token_ids.copy()
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
)
@@ -1031,7 +1032,12 @@ class Scheduler(SchedulerInterface):
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens.
- request.num_computed_tokens -= num_rejected
+ if request.num_computed_tokens > 0:
+ request.num_computed_tokens -= num_rejected
+ # If async scheduling, num_output_placeholders also includes
+ # the scheduled spec tokens count and so is similarly adjusted.
+ if request.num_output_placeholders > 0:
+ request.num_output_placeholders -= num_rejected
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 48ea6ef8515c9..c160c7cbcab4a 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -14,7 +14,7 @@ import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.protocol import Device, EngineClient
+from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.logger import init_logger
@@ -672,9 +672,7 @@ class AsyncLLM(EngineClient):
self.processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async()
- async def reset_prefix_cache(self, device: Device | None = None) -> None:
- if device == Device.CPU:
- raise ValueError("Not supported on CPU.")
+ async def reset_prefix_cache(self) -> None:
await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None:
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index a6965182fc2ce..3a25827cec385 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -63,7 +63,6 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
-from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -181,11 +180,14 @@ class EngineCore:
logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
self.batch_queue = deque(maxlen=self.batch_queue_size)
+ self.is_ec_producer = (
+ vllm_config.ec_transfer_config is not None
+ and vllm_config.ec_transfer_config.is_ec_producer
+ )
+ self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"
+
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
- if (
- self.vllm_config.cache_config.enable_prefix_caching
- or kv_connector is not None
- ):
+ if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo
)
@@ -198,6 +200,7 @@ class EngineCore:
self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue
)
+ self.async_scheduling = vllm_config.scheduler_config.async_scheduling
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
@@ -245,7 +248,7 @@ class EngineCore:
elapsed = time.time() - start
logger.info_once(
- ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
+ "init engine (profile, create kv cache, warmup model) took %.2f seconds",
elapsed,
scope="local",
)
@@ -311,6 +314,16 @@ class EngineCore:
)
raise err
+ def _log_err_callback(self, scheduler_output: SchedulerOutput):
+ """Log error details of a future that's not expected to return a result."""
+
+ def callback(f, sched_output=scheduler_output):
+ with self.log_error_detail(sched_output):
+ result = f.result()
+ assert result is None
+
+ return callback
+
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.
@@ -322,26 +335,25 @@ class EngineCore:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
- with record_function_or_nullcontext("core step: schedule"):
- scheduler_output = self.scheduler.schedule()
+ scheduler_output = self.scheduler.schedule()
+ future = self.model_executor.execute_model(scheduler_output, non_block=True)
+ grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+ with self.log_error_detail(scheduler_output):
+ model_output = future.result()
+ if model_output is None:
+ model_output = self.model_executor.sample_tokens(grammar_output)
- with record_function_or_nullcontext("core step: execute_model"):
- future = self.model_executor.execute_model(scheduler_output, non_block=True)
- grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
- with self.log_error_detail(scheduler_output):
- model_output = future.result()
- if model_output is None:
- model_output = self.model_executor.sample_tokens(grammar_output)
-
- with record_function_or_nullcontext("core step: update_from_output"):
- engine_core_outputs = self.scheduler.update_from_output(
- scheduler_output, model_output
- )
+ engine_core_outputs = self.scheduler.update_from_output(
+ scheduler_output, model_output
+ )
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
def post_step(self, model_executed: bool) -> None:
- if self.use_spec_decode and model_executed:
+ # When using async scheduling we can't get draft token ids in advance,
+ # so we update draft token ids in the worker process and don't
+ # need to update draft token ids here.
+ if not self.async_scheduling and self.use_spec_decode and model_executed:
# Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None:
@@ -374,52 +386,34 @@ class EngineCore:
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
- with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
- scheduler_output = self.scheduler.schedule()
- with record_function_or_nullcontext(
- "core step_with_batch_queue: execute_model"
- ):
- exec_future = self.model_executor.execute_model(
- scheduler_output, non_block=True
- )
- model_executed = scheduler_output.total_num_scheduled_tokens > 0
+ scheduler_output = self.scheduler.schedule()
+ exec_future = self.model_executor.execute_model(
+ scheduler_output, non_block=True
+ )
+ if not self.is_ec_producer:
+ model_executed = scheduler_output.total_num_scheduled_tokens > 0
- if scheduler_output.pending_structured_output_tokens:
- with record_function_or_nullcontext(
- "core step_with_batch_queue: pending_structured_output_tokens"
- ):
- # We need to defer sampling until we have processed the model output
- # from the prior step.
- deferred_scheduler_output = scheduler_output
- # Block-wait for execute to return
- # (continues running async on the GPU).
- with self.log_error_detail(scheduler_output):
- exec_result = exec_future.result()
- assert exec_result is None
+ if self.is_pooling_model or not model_executed:
+ # No sampling required (no requests scheduled).
+ future = cast(Future[ModelRunnerOutput], exec_future)
else:
- with record_function_or_nullcontext(
- "core step_with_batch_queue: get_grammar_bitmask"
- ):
- # We aren't waiting for any tokens, get any grammar
- # output immediately.
+ exec_future.add_done_callback(self._log_err_callback(scheduler_output))
+
+ if not scheduler_output.pending_structured_output_tokens:
+ # We aren't waiting for any tokens, get any grammar output
+ # and sample immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
- # Block-wait for execute to return (continues running async on the GPU).
- with self.log_error_detail(scheduler_output):
- exec_result = exec_future.result()
-
- if exec_result is None:
- with record_function_or_nullcontext(
- "core step_with_batch_queue: sample_tokens"
- ):
- # Call sample tokens.
- future = self.model_executor.sample_tokens(
- grammar_output, non_block=True
- )
+ future = self.model_executor.sample_tokens(
+ grammar_output, non_block=True
+ )
else:
- # No sampling required (e.g. all requests finished).
- future = cast(Future[ModelRunnerOutput], exec_future)
+ # We need to defer sampling until we have processed the model output
+ # from the prior step.
+ deferred_scheduler_output = scheduler_output
+
+ if not deferred_scheduler_output:
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
@@ -436,34 +430,27 @@ class EngineCore:
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
- with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
- # Block until the next result is available.
- future, scheduler_output = batch_queue.pop()
- with self.log_error_detail(scheduler_output):
- model_output = future.result()
- with record_function_or_nullcontext(
- "core step_with_batch_queue: update_from_output"
- ):
- engine_core_outputs = self.scheduler.update_from_output(
- scheduler_output, model_output
- )
+
+ # Block until the next result is available.
+ future, scheduler_output = batch_queue.pop()
+ with self.log_error_detail(scheduler_output):
+ model_output = future.result()
+
+ engine_core_outputs = self.scheduler.update_from_output(
+ scheduler_output, model_output
+ )
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
- with record_function_or_nullcontext(
- "core step_with_batch_queue: deferred_scheduler_output"
- ):
- # We now have the tokens needed to compute the bitmask for the
- # deferred request. Get the bitmask and call sample tokens.
- grammar_output = self.scheduler.get_grammar_bitmask(
- deferred_scheduler_output
- )
- future = self.model_executor.sample_tokens(
- grammar_output, non_block=True
- )
- batch_queue.appendleft((future, deferred_scheduler_output))
+ # We now have the tokens needed to compute the bitmask for the
+ # deferred request. Get the bitmask and call sample tokens.
+ grammar_output = self.scheduler.get_grammar_bitmask(
+ deferred_scheduler_output
+ )
+ future = self.model_executor.sample_tokens(grammar_output, non_block=True)
+ batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 1db83446ba0b5..e403cea87788b 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -14,7 +14,6 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.protocol import Device
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -321,7 +320,7 @@ class LLMEngine:
self.processor.clear_mm_cache()
self.engine_core.reset_mm_cache()
- def reset_prefix_cache(self, device: Device | None = None):
+ def reset_prefix_cache(self):
self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1):
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index b618d23472651..63064a2c65d67 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -43,15 +43,22 @@ class LogprobsProcessor:
tokenizer: AnyTokenizer | None,
request: EngineCoreRequest,
) -> "LogprobsProcessor":
- assert request.sampling_params is not None
- num_logprobs = request.sampling_params.logprobs
- num_prompt_logprobs = request.sampling_params.prompt_logprobs
+ sampling_params = request.sampling_params
+ assert sampling_params is not None
+ num_logprobs = sampling_params.logprobs
+ num_prompt_logprobs = sampling_params.prompt_logprobs
return cls(
tokenizer=tokenizer,
cumulative_logprob=(None if num_logprobs is None else 0.0),
- logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+ logprobs=(
+ None
+ if num_logprobs is None
+ else create_sample_logprobs(sampling_params.flat_logprobs)
+ ),
prompt_logprobs=(
- None if num_prompt_logprobs is None else create_prompt_logprobs()
+ None
+ if num_prompt_logprobs is None
+ else create_prompt_logprobs(sampling_params.flat_logprobs)
),
num_prompt_logprobs=num_prompt_logprobs,
num_logprobs=num_logprobs,
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 0404f6ff2771c..4cb911d8e22b7 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
+from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
@@ -149,6 +150,23 @@ class Processor:
raise ValueError(
"vLLM V1 does not support per request user provided logits processors."
)
+ # Async scheduling + spec decode currently incompatible with some
+ # sampling parameters.
+ if (
+ self.vllm_config.speculative_config is not None
+ and self.vllm_config.scheduler_config.async_scheduling
+ and (
+ params.frequency_penalty != 0.0
+ or params.presence_penalty != 0.0
+ or params.repetition_penalty != 1.0
+ or params.bad_words_token_ids
+ or params.structured_outputs
+ )
+ ):
+ raise ValueError(
+ "async scheduling with spec decoding doesn't yet support "
+ "penalties, bad words or structured outputs in sampling parameters."
+ )
def _validate_params(
self,
@@ -340,7 +358,12 @@ class Processor:
mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items():
- n = len(data) if isinstance(data, list) else 1
+ # Hash each item for embedding inputs.
+ n = (
+ len(data)
+ if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
+ else 1
+ )
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids
diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py
index e74519b21aa6e..d65cad7af03d6 100644
--- a/vllm/v1/engine/utils.py
+++ b/vllm/v1/engine/utils.py
@@ -183,15 +183,19 @@ def set_device_control_env_var(
for engine subprocess.
"""
world_size = vllm_config.parallel_config.world_size
+ local_world_size = vllm_config.parallel_config.local_world_size
evar = current_platform.device_control_env_var
- value = get_device_indices(evar, local_dp_rank, world_size)
+ value = get_device_indices(evar, local_dp_rank, world_size, local_world_size)
with patch.dict(os.environ, values=((evar, value),)):
yield
def get_device_indices(
- device_control_env_var: str, local_dp_rank: int, world_size: int
+ device_control_env_var: str,
+ local_dp_rank: int,
+ world_size: int,
+ local_world_size: int | None = None,
):
"""
Returns a comma-separated string of device indices for the specified
@@ -200,10 +204,15 @@ def get_device_indices(
For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
this will select devices 2 and 3 for local_dp_rank=1.
"""
+ if local_world_size is None:
+ local_world_size = world_size
try:
value = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
- for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)
+ for i in range(
+ local_dp_rank * world_size,
+ local_dp_rank * world_size + local_world_size,
+ )
)
except IndexError as e:
raise Exception(
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 881e6ef40aaf0..ad2ece50f9815 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -10,7 +10,7 @@ import time
import traceback
import weakref
from collections import deque
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress
from dataclasses import dataclass
@@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import (
get_dcp_group,
get_dp_group,
get_ep_group,
+ get_inner_dp_world_group,
get_pp_group,
get_tp_group,
)
@@ -90,6 +91,10 @@ class FutureWrapper(Future):
class MultiprocExecutor(Executor):
supports_pp: bool = True
+ def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True):
+ self.monitor_workers = monitor_workers
+ super().__init__(vllm_config)
+
def _init_executor(self) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
@@ -99,6 +104,12 @@ class MultiprocExecutor(Executor):
self.failure_callback: FailureCallback | None = None
self.world_size = self.parallel_config.world_size
+ assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
+ f"global world_size ({self.parallel_config.world_size}) must be "
+ f"divisible by nnodes_within_dp "
+ f"({self.parallel_config.nnodes_within_dp}). "
+ )
+ self.local_world_size = self.parallel_config.local_world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
@@ -116,27 +127,37 @@ class MultiprocExecutor(Executor):
distributed_init_method = get_distributed_init_method(
get_loopback_ip(), get_open_port()
)
-
+ self.rpc_broadcast_mq: MessageQueue | None = None
+ scheduler_output_handle: Handle | None = None
# Initialize worker and set up message queues for SchedulerOutputs
# and ModelRunnerOutputs
- max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
- self.rpc_broadcast_mq = MessageQueue(
- self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
- )
- scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
-
+ if self.parallel_config.node_rank_within_dp == 0:
+ # For leader node within each dp rank,
+ # each dp will have its own leader multiproc executor.
+ max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
+ self.rpc_broadcast_mq = MessageQueue(
+ self.world_size,
+ self.local_world_size,
+ max_chunk_bytes=max_chunk_bytes,
+ connect_ip=self.parallel_config.master_addr,
+ )
+ scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
context = get_mp_context()
shared_worker_lock = context.Lock()
unready_workers: list[UnreadyWorkerProcHandle] = []
success = False
try:
- for rank in range(self.world_size):
+ global_start_rank = (
+ self.local_world_size * self.parallel_config.node_rank_within_dp
+ )
+ for local_rank in range(self.local_world_size):
+ global_rank = global_start_rank + local_rank
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
- local_rank=rank,
- rank=rank,
+ local_rank=local_rank,
+ rank=global_rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
@@ -145,15 +166,38 @@ class MultiprocExecutor(Executor):
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
+
+ # Wait for all local workers to be ready.
self.workers = WorkerProc.wait_for_ready(unready_workers)
+ # Start background thread to monitor worker health if not in headless mode.
+ if self.monitor_workers:
+ self.start_worker_monitor()
+
+ self.response_mqs = []
+            # Only the leader node has remote response mqs
+ if self.parallel_config.node_rank_within_dp == 0:
+ for rank in range(self.world_size):
+ if rank < self.local_world_size:
+ local_message_queue = self.workers[rank].worker_response_mq
+ assert local_message_queue is not None
+ self.response_mqs.append(local_message_queue)
+ else:
+ remote_message_queue = self.workers[0].peer_worker_response_mqs[
+ rank
+ ]
+ assert remote_message_queue is not None
+ self.response_mqs.append(remote_message_queue)
+
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc.
- self.rpc_broadcast_mq.wait_until_ready()
- for w in self.workers:
- w.worker_response_mq.wait_until_ready()
- self.start_worker_monitor()
+ # Wait for all input mqs to be ready.
+ if self.rpc_broadcast_mq is not None:
+ self.rpc_broadcast_mq.wait_until_ready()
+ # Wait for all remote response mqs to be ready.
+ for response_mq in self.response_mqs:
+ response_mq.wait_until_ready()
success = True
finally:
if not success:
@@ -168,7 +212,7 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank()
- def start_worker_monitor(self):
+ def start_worker_monitor(self, inline=False) -> None:
workers = self.workers
self_ref = weakref.ref(self)
@@ -192,9 +236,13 @@ class MultiprocExecutor(Executor):
_self.failure_callback = None
callback()
- Thread(
- target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
- ).start()
+ if not inline:
+ Thread(
+ target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
+ ).start()
+ return
+
+ monitor_workers()
def register_failure_callback(self, callback: FailureCallback):
if self.is_failed:
@@ -247,7 +295,9 @@ class MultiprocExecutor(Executor):
) -> Any | list[Any] | Future[Any | list[Any]]:
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
is provided, otherwise list."""
-
+ assert self.rpc_broadcast_mq is not None, (
+ "collective_rpc should not be called on follower node"
+ )
if self.is_failed:
raise RuntimeError("Executor failed.")
@@ -269,20 +319,20 @@ class MultiprocExecutor(Executor):
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
- workers = (
- (self.workers[output_rank],) if output_rank is not None else self.workers
- )
+ response_mqs: Sequence[MessageQueue] = self.response_mqs
+ if output_rank is not None:
+ response_mqs = (response_mqs[output_rank],)
shutdown_event = self.shutdown_event
def get_response():
responses = []
- for w in workers:
+ for mq in response_mqs:
dequeue_timeout = (
None if deadline is None else (deadline - time.monotonic())
)
try:
- status, result = w.worker_response_mq.dequeue(
+ status, result = mq.dequeue(
timeout=dequeue_timeout, cancel=shutdown_event
)
except TimeoutError as e:
@@ -391,17 +441,26 @@ class UnreadyWorkerProcHandle:
class WorkerProcHandle:
proc: BaseProcess
rank: int
- worker_response_mq: MessageQueue # The worker process writes to this MQ
+ # The worker process writes to this MQ in single-node mode
+ worker_response_mq: MessageQueue | None
+ # This is only non-empty on the driver node;
+ # peer worker process i writes to MQ
+ # `peer_worker_response_mqs[i]`
+ peer_worker_response_mqs: list[MessageQueue | None]
death_writer: Connection | None = None
@classmethod
def from_unready_handle(
- cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
+ cls,
+ unready_handle: UnreadyWorkerProcHandle,
+ worker_response_mq: MessageQueue | None,
+ peer_worker_response_mqs: list[MessageQueue | None],
) -> "WorkerProcHandle":
return cls(
proc=unready_handle.proc,
rank=unready_handle.rank,
worker_response_mq=worker_response_mq,
+ peer_worker_response_mqs=peer_worker_response_mqs,
death_writer=unready_handle.death_writer,
)
@@ -411,6 +470,38 @@ class WorkerProc:
READY_STR = "READY"
+ def _init_message_queues(
+ self, input_shm_handle: Handle, vllm_config: VllmConfig
+ ) -> None:
+ if vllm_config.parallel_config.nnodes_within_dp == 1:
+ # Initialize MessageQueue for receiving SchedulerOutput
+ self.rpc_broadcast_mq = MessageQueue.create_from_handle(
+ input_shm_handle, self.worker.rank
+ )
+
+ # Initializes a message queue for sending the model output
+ self.worker_response_mq: MessageQueue = MessageQueue(1, 1)
+ self.peer_response_handles = []
+ else:
+ # Initialize remote MessageQueue for receiving SchedulerOutput across nodes
+ self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
+ external_writer_handle=input_shm_handle,
+ # Since the external_writer_handle comes from the executor proc,
+ # the ready signal from the actual writer is sent outside of the
+ # create_mq_broadcaster method and after this setup, so we make it
+ # non-blocking. The handshake will be triggered when
+ # worker.rpc_broadcast_mq.wait_until_ready() is called
+ blocking=False,
+ )
+ # Initializes remote message queue for sending the model output to the
+ # driver worker, exposing peer_response_handles for driver worker
+ # that include handles for all ranks
+ self.worker_response_mq, self.peer_response_handles = (
+ get_inner_dp_world_group().create_single_reader_mq_broadcasters(
+ reader_rank_in_group=0
+ )
+ )
+
def __init__(
self,
vllm_config: VllmConfig,
@@ -421,13 +512,15 @@ class WorkerProc:
shared_worker_lock: LockType,
):
self.rank = rank
- wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
+ wrapper = WorkerWrapperBase(
+ vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
+ )
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
- all_kwargs[rank] = {
+ all_kwargs[local_rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
@@ -438,14 +531,6 @@ class WorkerProc:
wrapper.init_worker(all_kwargs)
self.worker = wrapper
- # Initialize MessageQueue for receiving SchedulerOutput
- self.rpc_broadcast_mq = MessageQueue.create_from_handle(
- input_shm_handle, self.worker.rank
- )
-
- # Initializes a message queue for sending the model output
- self.worker_response_mq = MessageQueue(1, 1)
-
scheduler_config = vllm_config.scheduler_config
self.use_async_scheduling = scheduler_config.async_scheduling
if self.use_async_scheduling:
@@ -466,6 +551,7 @@ class WorkerProc:
)
# Load model
+ self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more
@@ -512,6 +598,27 @@ class WorkerProc:
# death_reader in child will get EOFError
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
+ @staticmethod
+ def wait_for_response_handle_ready(
+ handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
+ ) -> WorkerProcHandle:
+ response_handle = handles["handle"]
+ worker_response_mq: MessageQueue | None = None
+ if len(response_handle.local_reader_ranks) > 0:
+ worker_response_mq = MessageQueue.create_from_handle(response_handle, 0)
+ peer_response_handles = handles["peer_response_handles"]
+ peer_worker_response_mqs = [
+ MessageQueue.create_from_handle(handle, -1)
+ if handle.remote_subscribe_addr is not None
+ else None
+ for handle in peer_response_handles
+ ]
+ return WorkerProcHandle.from_unready_handle(
+ proc_handle,
+ worker_response_mq,
+ peer_worker_response_mqs=peer_worker_response_mqs,
+ )
+
@staticmethod
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle],
@@ -537,16 +644,10 @@ class WorkerProc:
if response["status"] != "READY":
raise e
- # Extract the message queue handle.
- worker_response_mq = MessageQueue.create_from_handle(
- response["handle"], 0
+ idx = unready_proc_handle.rank % len(ready_proc_handles)
+ ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
+ response, unready_proc_handle
)
- ready_proc_handles[unready_proc_handle.rank] = (
- WorkerProcHandle.from_unready_handle(
- unready_proc_handle, worker_response_mq
- )
- )
-
except EOFError:
e.__suppress_context__ = True
raise e from None
@@ -618,12 +719,14 @@ class WorkerProc:
{
"status": WorkerProc.READY_STR,
"handle": worker.worker_response_mq.export_handle(),
+ "peer_response_handles": worker.peer_response_handles,
}
)
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
- worker.rpc_broadcast_mq.wait_until_ready()
+ if worker.rpc_broadcast_mq is not None:
+ worker.rpc_broadcast_mq.wait_until_ready()
worker.worker_response_mq.wait_until_ready()
ready_writer.close()
ready_writer = None
diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py
index 119e4c0818316..406eafcd339b0 100644
--- a/vllm/v1/executor/ray_executor.py
+++ b/vllm/v1/executor/ray_executor.py
@@ -99,6 +99,11 @@ class RayDistributedExecutor(Executor):
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
+ self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+ self.vllm_config.ec_transfer_config is None
+ or not self.vllm_config.ec_transfer_config.is_ec_producer
+ )
+
self.scheduler_output: SchedulerOutput | None = None
@property
@@ -395,6 +400,12 @@ class RayDistributedExecutor(Executor):
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
+
+ if not self.uses_sampler or not scheduler_output.total_num_scheduled_tokens:
+ # Model will not execute, call model runner immediately.
+ return self._execute_dag(scheduler_output, None, non_block)
+
+ # Model will execute, defer to sample_tokens() call.
self.scheduler_output = scheduler_output
return COMPLETED_NONE_FUTURE if non_block else None
@@ -417,10 +428,18 @@ class RayDistributedExecutor(Executor):
"""
scheduler_output = self.scheduler_output
if scheduler_output is None:
- return None # noqa
+ return COMPLETED_NONE_FUTURE if non_block else None # noqa
self.scheduler_output = None
+ return self._execute_dag(scheduler_output, grammar_output, non_block)
+
+ def _execute_dag(
+ self,
+ scheduler_output: SchedulerOutput,
+ grammar_output: "GrammarOutput | None",
+ non_block: bool = False,
+ ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py
index 646f9d0d75423..0f2ec4a1b41f3 100644
--- a/vllm/v1/kv_offload/worker/cpu_gpu.py
+++ b/vllm/v1/kv_offload/worker/cpu_gpu.py
@@ -68,9 +68,9 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
self.h2d_stream = torch.cuda.Stream()
# job_id -> transfer cuda event
- self.transfer_events: dict[int, torch.cuda.Event] = {}
+ self.transfer_events: dict[int, torch.Event] = {}
# list of cuda events available for re-use
- self.events_pool: list[torch.cuda.Event] = []
+ self.events_pool: list[torch.Event] = []
pin_memory = is_pin_memory_available()
@@ -153,7 +153,7 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
)
src_to_dst_tensor = torch.from_numpy(src_to_dst)
- event = self.events_pool.pop() if self.events_pool else torch.cuda.Event()
+ event = self.events_pool.pop() if self.events_pool else torch.Event()
with torch.cuda.stream(stream):
for src_tensor, dst_tensor, kv_dim in zip(
src_tensors, dst_tensors, self.kv_dim_before_num_blocks
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 21280b9c84cf2..cb36e7973650e 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -494,6 +494,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
gauge_kv_cache_usage = self._gauge_cls(
name="vllm:kv_cache_usage_perc",
documentation="KV-cache usage. 1 means 100 percent usage.",
+ multiprocess_mode="mostrecent",
labelnames=labelnames,
)
self.gauge_kv_cache_usage = make_per_engine(
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 60ee9671e4977..c0b2835c3124c 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -220,7 +220,7 @@ def make_empty_encoder_model_runner_output(
req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}
# No tokens generated yet ⇒ one empty list per request
- sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]
+ sampled_token_ids: list[list[int]] = [np.array([0]) for _ in req_ids]
# Pooler outputs are not available yet ⇒ use None placeholders
pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 7a5f1183ed48e..3d92906fbf4b1 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -127,6 +127,8 @@ class Request:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
+ self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
+
@classmethod
def from_engine_core_request(
cls,
@@ -180,6 +182,19 @@ class Request:
def num_output_tokens(self) -> int:
return len(self._output_token_ids)
+ def get_skip_reading_prefix_cache(self) -> bool:
+ if (
+ self.sampling_params is not None
+ and self.sampling_params.skip_reading_prefix_cache is not None
+ ):
+ return self.sampling_params.skip_reading_prefix_cache
+ elif (
+ self.pooling_params is not None
+ and self.pooling_params.skip_reading_prefix_cache is not None
+ ):
+ return self.pooling_params.skip_reading_prefix_cache
+ return False
+
def is_finished(self) -> bool:
return RequestStatus.is_finished(self.status)
diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py
index 5992c4066c9cb..8b174af4c7794 100644
--- a/vllm/v1/sample/logits_processor/__init__.py
+++ b/vllm/v1/sample/logits_processor/__init__.py
@@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = (
# Error message when the user tries to initialize vLLM with a speculative
# decoding enabled and custom logitsproces
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
- "Custom logits processors are not supportedwhen speculative decoding is enabled."
+ "Custom logits processors are not supported when speculative decoding is enabled."
)
LOGITSPROCS_GROUP = "vllm.logits_processors"
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 02ea658b7f20e..c6c7e924175f7 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -7,6 +7,7 @@ import torch.nn as nn
from packaging import version
from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
@@ -55,6 +56,17 @@ class TopKTopPSampler(nn.Module):
self.forward = self.forward_native
else:
self.forward = self.forward_cpu
+ elif (
+ logprobs_mode not in ("processed_logits", "processed_logprobs")
+ and rocm_aiter_ops.is_enabled()
+ ):
+ import aiter.ops.sampling # noqa: F401
+
+ self.aiter_ops = torch.ops.aiter
+ logger.info_once(
+ "Using aiter sampler on ROCm (lazy import, sampling-only)."
+ )
+ self.forward = self.forward_hip
else:
self.forward = self.forward_native
@@ -138,6 +150,64 @@ class TopKTopPSampler(nn.Module):
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
+ def forward_hip(
+ self,
+ logits: torch.Tensor,
+ generators: dict[int, torch.Generator],
+ k: torch.Tensor | None,
+ p: torch.Tensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """Optimized ROCm/aiter path (same structure as forward_cuda)."""
+ if (k is None and p is None) or generators:
+ if generators:
+ logger.warning_once(
+ "aiter sampler does not support per-request generators; "
+ "falling back to PyTorch-native."
+ )
+ return self.forward_native(logits, generators, k, p)
+ assert self.logprobs_mode not in (
+ "processed_logits",
+ "processed_logprobs",
+ ), "aiter sampler does not support returning logits/logprobs."
+ return self.aiter_sample(logits, k, p, generators), None
+
+ def aiter_sample(
+ self,
+ logits: torch.Tensor,
+ k: torch.Tensor | None,
+ p: torch.Tensor | None,
+ generators: dict[int, torch.Generator],
+ ) -> torch.Tensor:
+ """Sample from logits using aiter ops."""
+ use_top_k = k is not None
+ use_top_p = p is not None
+ # Joint k+p path
+ if use_top_p and use_top_k:
+ probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+ next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
+ probs,
+ None,
+ *_to_tensor_scalar_tuple(k),
+ *_to_tensor_scalar_tuple(p),
+ deterministic=True,
+ )
+ return next_token_ids.view(-1)
+ # Top-p only path
+ elif use_top_p:
+ probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+ next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
+ probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
+ )
+ return next_token_ids.view(-1)
+ # Top-k only path
+ elif use_top_k:
+ probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+ renorm_probs = self.aiter_ops.top_k_renorm_probs(
+ probs, *_to_tensor_scalar_tuple(k)
+ )
+ return torch.multinomial(renorm_probs, num_samples=1).view(-1)
+ raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
+
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@@ -288,3 +358,10 @@ def flashinfer_sample(
)
return next_token_ids.view(-1)
+
+
+def _to_tensor_scalar_tuple(x):
+ if isinstance(x, torch.Tensor):
+ return (x, 0)
+ else:
+ return (None, x)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index f3b34544f8d91..5bf2503c3027d 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -397,10 +397,13 @@ class EagleProposer:
positions += 1
exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
-
+ # For data integrity with async scheduling, we shouldn't use in-place
+ # operations, in case the tensors are modified in the next step's
+ # `prepare_input` of the main model.
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
- common_attn_metadata.seq_lens_cpu += 1
+ # This is an out-of-place operation to avoid modifying the original tensor.
+ common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
@@ -991,6 +994,7 @@ class EagleProposer:
target_language_model = target_model.get_language_model()
else:
target_language_model = target_model
+
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
if hasattr(target_language_model.model, "embed_tokens"):
@@ -1002,52 +1006,92 @@ class EagleProposer:
"Target model does not have 'embed_tokens' or 'embedding' attribute"
)
- # Check if shapes match and we found the embedding
- eagle_shape = self.model.model.embed_tokens.weight.shape
- target_shape = target_embed_tokens.weight.shape
- if eagle_shape == target_shape:
- logger.info(
- "Assuming the EAGLE head shares the same vocab embedding"
- " with the target model."
- )
- del self.model.model.embed_tokens
- self.model.model.embed_tokens = target_embed_tokens
+ share_embeddings = False
+ if hasattr(self.model, "has_own_embed_tokens"):
+ # EAGLE model
+ if not self.model.has_own_embed_tokens:
+ share_embeddings = True
+ logger.info(
+ "Detected EAGLE model without its own embed_tokens in the"
+ " checkpoint. Sharing target model embedding weights with the"
+ " draft model."
+ )
+ elif (
+ isinstance(target_embed_tokens.weight, torch.Tensor)
+ and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
+ and torch.equal(
+ target_embed_tokens.weight, self.model.model.embed_tokens.weight
+ )
+ ):
+ share_embeddings = True
+ logger.info(
+ "Detected EAGLE model with embed_tokens identical to the target"
+ " model. Sharing target model embedding weights with the draft"
+ " model."
+ )
+ else:
+ logger.info(
+ "Detected EAGLE model with distinct embed_tokens weights. "
+ "Keeping separate embedding weights from the target model."
+ )
else:
+ # MTP model
+ share_embeddings = True
logger.info(
- "The EAGLE head's vocab embedding will be loaded separately"
- " from the target model."
+ "Detected MTP model. "
+ "Sharing target model embedding weights with the draft model."
)
+
+ if share_embeddings:
+ if hasattr(self.model.model, "embed_tokens"):
+ del self.model.model.embed_tokens
+ self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
- "The EAGLE head's vocab embedding will be loaded separately"
+ "The draft model's vocab embedding will be loaded separately"
" from the target model."
)
# share lm_head with the target model if needed
- # some model definition do not define lm_head explicitly
- # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
- if self.vllm_config.speculative_config.method != "eagle3":
- if hasattr(target_language_model, "lm_head"):
- logger.info("Loading EAGLE LM head weights from the target model.")
- self.model.lm_head = target_language_model.lm_head
- else:
- if (
- hasattr(self.model, "lm_head")
- and hasattr(target_language_model, "lm_head")
- and self.model.lm_head.weight.shape
- == target_language_model.lm_head.weight.shape
- ):
+ share_lm_head = False
+ if hasattr(self.model, "has_own_lm_head"):
+ # EAGLE model
+ if not self.model.has_own_lm_head:
+ share_lm_head = True
logger.info(
- "Assuming the EAGLE head shares the same lm_head"
- " with the target model."
+ "Detected EAGLE model without its own lm_head in the checkpoint. "
+ "Sharing target model lm_head weights with the draft model."
+ )
+ elif (
+ hasattr(target_language_model, "lm_head")
+ and isinstance(target_language_model.lm_head.weight, torch.Tensor)
+ and isinstance(self.model.lm_head.weight, torch.Tensor)
+ and torch.equal(
+ target_language_model.lm_head.weight, self.model.lm_head.weight
+ )
+ ):
+ share_lm_head = True
+ logger.info(
+ "Detected EAGLE model with lm_head identical to the target model. "
+ "Sharing target model lm_head weights with the draft model."
)
- del self.model.lm_head
- self.model.lm_head = target_language_model.lm_head
else:
logger.info(
- "The EAGLE head's lm_head will be loaded separately"
- " from the target model."
+ "Detected EAGLE model with distinct lm_head weights. "
+ "Keeping separate lm_head weights from the target model."
)
+ else:
+ # MTP model
+ share_lm_head = True
+ logger.info(
+ "Detected MTP model. "
+ "Sharing target model lm_head weights with the draft model."
+ )
+
+ if share_lm_head and hasattr(target_language_model, "lm_head"):
+ if hasattr(self.model, "lm_head"):
+ del self.model.lm_head
+ self.model.lm_head = target_language_model.lm_head
@torch.inference_mode()
def dummy_run(
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index a401f6d74cdd5..29099d1e9b17e 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -97,6 +97,9 @@ class ConstantList(Generic[T], Sequence):
def __repr__(self):
return f"ConstantList({self._x})"
+ def copy(self) -> list[T]:
+ return self._x.copy()
+
class CpuGpuBuffer:
"""Buffer to easily copy tensors between CPU and GPU."""
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
index ceb1cf64b5889..6bfbc32d598fa 100644
--- a/vllm/v1/worker/cpu_model_runner.py
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -80,9 +80,6 @@ class CPUModelRunner(GPUModelRunner):
def _sync_device(self) -> None:
pass
- def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
- return sampled_token_ids.tolist()
-
def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
# Note: For CPU backend, dp padding is not required for now.
return 0, None
@@ -99,14 +96,14 @@ def _torch_cuda_wrapper():
def __init__(self, *args, **kwargs) -> None:
pass
- cuda_event = torch.cuda.Event
+ cuda_event = torch.Event
cuda_stream = torch.cuda.Stream
try:
- torch.cuda.Event = _EventPlaceholder
+ torch.Event = _EventPlaceholder
torch.cuda.Stream = _StreamPlaceholder
yield
finally:
- torch.cuda.Event = cuda_event
+ torch.Event = cuda_event
torch.cuda.Stream = cuda_stream
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 393181f543d2e..023b5edb2c340 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -46,6 +46,9 @@ class CachedRequestState:
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
+ # Used when both async_scheduling and spec_decode are enabled.
+ prev_num_draft_len: int = 0
+
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
@@ -262,7 +265,7 @@ class InputBatch:
# ids from prior step, if required by current sampling params
# (e.g. penalties).
self.sampled_token_ids_cpu: torch.Tensor | None = None
- self.async_copy_ready_event: torch.cuda.Event | None = None
+ self.async_copy_ready_event: torch.Event | None = None
@property
def req_ids(self) -> list[str]:
@@ -888,7 +891,7 @@ class InputBatch:
def set_async_sampled_token_ids(
self,
sampled_token_ids_cpu: torch.Tensor,
- async_copy_ready_event: torch.cuda.Event,
+ async_copy_ready_event: torch.Event,
) -> None:
"""
In async scheduling case, store ref to sampled_token_ids_cpu
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d0d6164180e66..506118d2d762b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -7,7 +7,7 @@ import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
-from copy import deepcopy
+from copy import copy, deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -179,16 +179,18 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
logprobs_tensors: torch.Tensor | None,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
+ vocab_size: int,
):
self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices
# Event on the copy stream so we can synchronize the non-blocking copy.
- self.async_copy_ready_event = torch.cuda.Event()
+ self.async_copy_ready_event = torch.Event()
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids
+ self.vocab_size = vocab_size
self._logprobs_tensors = logprobs_tensors
# Initiate the copy on a separate stream, but do not synchronize it.
@@ -215,10 +217,16 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
# Release the device tensors once the copy has completed.
del self._logprobs_tensors
del self._sampled_token_ids
-
- valid_sampled_token_ids: list[np.ndarray] = [
- row for row in self.sampled_token_ids_cpu.numpy()
- ]
+ max_gen_len = self.sampled_token_ids_cpu.shape[-1]
+ if max_gen_len == 1:
+ valid_sampled_token_ids: list[np.ndarray] = [
+ row for row in self.sampled_token_ids_cpu.numpy()
+ ]
+ else:
+ valid_sampled_token_ids = RejectionSampler.parse_output(
+ self.sampled_token_ids_cpu,
+ self.vocab_size,
+ )
for i in self._invalid_req_indices:
valid_sampled_token_ids[i] = np.array([])
@@ -242,7 +250,6 @@ class ExecuteModelState(NamedTuple):
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
- kv_connector_output: KVConnectorOutput | None
ec_connector_output: ECConnectorOutput | None
@@ -317,6 +324,7 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
+ self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@@ -377,6 +385,10 @@ class GPUModelRunner(
)
self.rejection_sampler = RejectionSampler(self.sampler)
+ self.num_spec_tokens = 0
+ if self.speculative_config:
+ self.num_spec_tokens = self.speculative_config.num_speculative_tokens
+
# Request states.
self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream()
@@ -423,10 +435,10 @@ class GPUModelRunner(
self.async_output_copy_stream: torch.cuda.Stream | None = None
# cuda event to synchronize use of reused CPU tensors between steps
# when async scheduling is enabled.
- self.prepare_inputs_event: torch.cuda.Event | None = None
+ self.prepare_inputs_event: torch.Event | None = None
if self.use_async_scheduling:
self.async_output_copy_stream = torch.cuda.Stream()
- self.prepare_inputs_event = torch.cuda.Event()
+ self.prepare_inputs_event = torch.Event()
# self.cudagraph_batch_sizes sorts in ascending order.
if (
@@ -513,11 +525,7 @@ class GPUModelRunner(
self.max_num_tokens, dtype=torch.int32, device=self.device
)
- self.uniform_decode_query_len = (
- 1
- if not self.speculative_config
- else 1 + self.speculative_config.num_speculative_tokens
- )
+ self.uniform_decode_query_len = 1 + self.num_spec_tokens
# Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
@@ -541,7 +549,7 @@ class GPUModelRunner(
# Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
- self.transfer_event = torch.cuda.Event()
+ self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_num_reqs, 1),
dtype=torch.int64,
@@ -549,8 +557,23 @@ class GPUModelRunner(
pin_memory=self.pin_memory,
)
+ # Pre-allocated tensor for copying valid sampled token counts to CPU,
+ # with dedicated stream for overlapping and event for coordination.
+ self.valid_sampled_token_count_event: torch.Event | None = None
+ self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
+ if self.use_async_scheduling and self.num_spec_tokens:
+ self.valid_sampled_token_count_event = torch.Event()
+ self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
+ self.valid_sampled_token_count_cpu = torch.empty(
+ self.max_num_reqs,
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
+ self.kv_connector_output: KVConnectorOutput | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
@@ -630,16 +653,6 @@ class GPUModelRunner(
return
if self.reorder_batch_threshold is not None:
- # NOTE(lucas): currently no backend supports the custom masking
- # required for DCP with q_len > 1, so we assert here. Remove this
- # assert once the custom mask is support is added to FA3.
- if (
- self.dcp_world_size > 1
- and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
- ):
- assert self.reorder_batch_threshold == 1, (
- "DCP not support reorder_batch_threshold > 1 now."
- )
reorder_batch_to_split_decodes_and_prefills(
self.input_batch,
scheduler_output,
@@ -746,17 +759,45 @@ class GPUModelRunner(
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
+
+ # Wait until valid_sampled_tokens_count is copied to cpu,
+ # then use it to update actual num_computed_tokens of each request.
+ valid_sampled_token_count = self._get_valid_sampled_token_count()
+
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]
+ req_index = self.input_batch.req_id_to_index.get(req_id)
+
+ # prev_num_draft_len is used in async scheduling mode with
+ # spec decode. It indicates whether num_computed_tokens of
+ # the request needs to be updated. For example:
+ # first step: num_computed_tokens = 0, spec_tokens = [],
+ # prev_num_draft_len = 0.
+ # second step: num_computed_tokens = 100 (prompt length),
+ # spec_tokens = [a,b], prev_num_draft_len = 0.
+ # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
+ # prev_num_draft_len = 2.
+ # num_computed_tokens in the first and second steps doesn't
+ # include the spec tokens length, but in the third step it does.
+ # We only need to update num_computed_tokens when
+ # prev_num_draft_len > 0.
+ if req_state.prev_num_draft_len:
+ if req_index is None:
+ req_state.prev_num_draft_len = 0
+ else:
+ assert self.input_batch.prev_req_id_to_index is not None
+ prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
+ num_accepted = valid_sampled_token_count[prev_req_index] - 1
+ num_rejected = req_state.prev_num_draft_len - num_accepted
+ num_computed_tokens -= num_rejected
+ req_state.output_token_ids.extend([-1] * num_accepted)
# Update the cached states.
-
req_state.num_computed_tokens = num_computed_tokens
- req_index = self.input_batch.req_id_to_index.get(req_id)
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
@@ -833,8 +874,11 @@ class GPUModelRunner(
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, []
)
- if spec_token_ids:
- num_spec_tokens = len(spec_token_ids)
+ num_spec_tokens = len(spec_token_ids)
+ # For async scheduling, token_ids_cpu assigned from
+ # spec_token_ids are placeholders and will be overwritten in
+ # _prepare_input_ids.
+ if num_spec_tokens:
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[
@@ -850,6 +894,15 @@ class GPUModelRunner(
# even when speculative decoding is enabled.
self.input_batch.spec_token_ids[req_index] = spec_token_ids
+ # When there are no draft tokens with async scheduling,
+ # we clear the spec_decoding info in scheduler_output and
+ # use normal sampling instead of rejection sampling.
+ if self.use_async_scheduling:
+ req_state.prev_num_draft_len = num_spec_tokens
+ if num_spec_tokens and self._draft_token_ids is None:
+ scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
+ scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
+ scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
for request in reqs_to_add:
@@ -969,7 +1022,10 @@ class GPUModelRunner(
return cu_num_tokens, arange
def _prepare_input_ids(
- self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray
+ self,
+ scheduler_output: "SchedulerOutput",
+ total_num_scheduled_tokens: int,
+ cu_num_tokens: np.ndarray,
) -> None:
"""Prepare the input IDs for the current batch.
@@ -990,21 +1046,43 @@ class GPUModelRunner(
# on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
- flattened_indices = []
- prev_common_req_indices = []
+ sample_flattened_indices: list[int] = []
+ spec_flattened_indices: list[int] = []
+ prev_common_req_indices: list[int] = []
+ prev_draft_token_indices: list[int] = []
indices_match = True
max_flattened_index = -1
+ total_num_spec_tokens = 0
+ scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+
for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
+ draft_len = len(scheduled_spec_tokens.get(req_id, ()))
+ total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
- flattened_indices.append(flattened_index)
+ # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
+ # sample_flattened_indices = [0, 2, 5]
+ # spec_flattened_indices = [1, 3, 4, 6, 7]
+ sample_flattened_indices.append(flattened_index - draft_len)
+ spec_flattened_indices.extend(
+ range(flattened_index - draft_len + 1, flattened_index + 1)
+ )
+ start = prev_index * self.num_spec_tokens
+ # prev_draft_token_indices is used to find which draft_tokens_id
+ # should be copied to input_ids
+ # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
+ # flatten draft_tokens_id [1,2,3,4,5,6]
+ # draft_len of each request [1, 2, 1]
+ # then prev_draft_token_indices is [0, 2, 3, 4]
+ prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
- num_commmon_tokens = len(flattened_indices)
- if num_commmon_tokens < total_num_scheduled_tokens:
+ num_commmon_tokens = len(sample_flattened_indices)
+ total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
+ if num_commmon_tokens < total_without_spec:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
@@ -1028,20 +1106,43 @@ class GPUModelRunner(
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously so the scatter can be non-blocking.
- input_ids_index_tensor = torch.tensor(
- flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
+ sampled_tokens_index_tensor = torch.tensor(
+ sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
- index=input_ids_index_tensor,
+ index=sampled_tokens_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0
],
)
+ # Scatter the draft tokens after the sampled tokens are scattered.
+ if self._draft_token_ids is None or not spec_flattened_indices:
+ return
+
+ assert isinstance(self._draft_token_ids, torch.Tensor)
+ draft_tokens_index_tensor = torch.tensor(
+ spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
+ ).to(self.device, non_blocking=True)
+ prev_draft_token_indices_tensor = torch.tensor(
+ prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
+ ).to(self.device, non_blocking=True)
+
+ # input_ids dtype is torch.int32, so convert
+ # draft_token_ids to torch.int32 here.
+ draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
+ self._draft_token_ids = None
+
+ self.input_ids.gpu.scatter_(
+ dim=0,
+ index=draft_tokens_index_tensor,
+ src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
+ )
+
def _get_encoder_seq_lens(
self,
scheduled_encoder_inputs: dict[str, list[int]],
@@ -1228,7 +1329,11 @@ class GPUModelRunner(
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
# Copy the tensors to the GPU.
- self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
+ self._prepare_input_ids(
+ scheduler_output,
+ total_num_scheduled_tokens,
+ cu_num_tokens,
+ )
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -2242,6 +2347,24 @@ class GPUModelRunner(
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
+
+ # Generate custom attention masks for models that require them.
+ # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
+ # Check mm_features (mm_embeds is empty during decode).
+ has_mm_features = any(
+ req_state.mm_features for req_state in self.requests.values()
+ )
+ if (
+ self.uses_custom_attention_masks
+ and has_mm_features
+ and hasattr(self.model, "generate_attention_masks")
+ ):
+ mask_kwargs = self.model.generate_attention_masks(
+ self.input_ids.gpu[:num_scheduled_tokens],
+ self.positions.gpu[:num_scheduled_tokens],
+ mask_dtype=self.model.dtype,
+ )
+ model_kwargs.update(mask_kwargs)
elif self.enable_prompt_embeds and is_first_rank:
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.
@@ -2387,12 +2510,14 @@ class GPUModelRunner(
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
invalid_req_indices_set = set(invalid_req_indices)
- assert sampled_token_ids.shape[-1] == 1
# Cache the sampled tokens on the GPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
- self.input_batch.prev_sampled_token_ids = sampled_token_ids
+ # With spec decoding, this is done in propose_draft_token_ids().
+ if self.input_batch.prev_sampled_token_ids is None:
+ assert sampled_token_ids.shape[-1] == 1
+ self.input_batch.prev_sampled_token_ids = sampled_token_ids
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
@@ -2527,6 +2652,21 @@ class GPUModelRunner(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
+
+ # self._draft_token_ids is None when `input_fits_in_drafter=False`
+ # and there are no draft tokens scheduled, so we need to update the
+ # spec_decoding info in scheduler_output with async_scheduling.
+ # Use deepcopy so the modification does not affect the
+ # scheduler_output in the engine core process.
+ # TODO(Ronald1995): deepcopy is expensive when there is a large
+ # number of requests, optimize it later.
+ if (
+ self.use_async_scheduling
+ and self.num_spec_tokens
+ and self._draft_token_ids is None
+ ):
+ scheduler_output = deepcopy(scheduler_output)
+
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
with self.synchronize_input_prep():
@@ -2542,6 +2682,18 @@ class GPUModelRunner(
return make_empty_encoder_model_runner_output(scheduler_output)
if not num_scheduled_tokens:
+ if (
+ self.parallel_config.distributed_executor_backend
+ == "external_launcher"
+ and self.parallel_config.data_parallel_size > 1
+ ):
+ # this is a corner case when both external launcher
+ # and DP are enabled, num_scheduled_tokens could be
+ # 0, and has_unfinished_requests in the outer loop
+ # returns True. before returning early here we call
+ # dummy run to ensure coordinate_batch_across_dp
+ # is called into to avoid out of sync issues.
+ self._dummy_run(1)
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
@@ -2682,6 +2834,7 @@ class GPUModelRunner(
# Return the intermediate tensors.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
+ self.kv_connector_output = kv_connector_output
return hidden_states
if self.is_pooling_model:
@@ -2732,18 +2885,31 @@ class GPUModelRunner(
hidden_states,
sample_hidden_states,
aux_hidden_states,
- kv_connector_output,
ec_connector_output,
)
+ self.kv_connector_output = kv_connector_output
return None
@torch.inference_mode
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+ kv_connector_output = self.kv_connector_output
+ self.kv_connector_output = None
+
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
- return None # noqa
+ if not kv_connector_output:
+ return None # noqa
+
+ # In case of PP with kv transfer, we need to pass through the
+ # kv_connector_output
+ if kv_connector_output.is_empty():
+ return EMPTY_MODEL_RUNNER_OUTPUT
+
+ output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
+ output.kv_connector_output = kv_connector_output
+ return output
# Unpack ephemeral state.
(
@@ -2754,7 +2920,6 @@ class GPUModelRunner(
hidden_states,
sample_hidden_states,
aux_hidden_states,
- kv_connector_output,
ec_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
@@ -2769,6 +2934,8 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
+ self.input_batch.prev_sampled_token_ids = None
+
def propose_draft_token_ids(
sampled_token_ids: torch.Tensor | list[np.ndarray],
) -> None:
@@ -2802,14 +2969,29 @@ class GPUModelRunner(
self.speculative_config.draft_model_config.max_model_len
)
input_fits_in_drafter = spec_decode_common_attn_metadata and (
- spec_decode_common_attn_metadata.max_seq_len
- + self.speculative_config.num_speculative_tokens
+ spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= effective_drafter_max_model_len
)
- if use_padded_batch_for_eagle and input_fits_in_drafter:
- # EAGLE speculative decoding can use the GPU sampled tokens
- # as inputs, and does not need to wait for bookkeeping to finish.
- propose_draft_token_ids(sampler_output.sampled_token_ids)
+ if use_padded_batch_for_eagle:
+ sampled_token_ids = sampler_output.sampled_token_ids
+ if input_fits_in_drafter:
+ # EAGLE speculative decoding can use the GPU sampled tokens
+ # as inputs, and does not need to wait for bookkeeping to finish.
+ propose_draft_token_ids(sampled_token_ids)
+ elif self.valid_sampled_token_count_event is not None:
+ next_token_ids, valid_sampled_tokens_count = (
+ self.drafter.prepare_next_token_ids_padded(
+ spec_decode_common_attn_metadata,
+ sampled_token_ids,
+ self.requests,
+ self.input_batch,
+ self.discard_request_indices.gpu,
+ self.num_discarded_requests,
+ )
+ )
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count
+ )
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
@@ -2866,12 +3048,13 @@ class GPUModelRunner(
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
+ vocab_size=self.input_batch.vocab_size,
)
with record_function_or_nullcontext(
"gpu_model_runner: set_async_sampled_token_ids"
):
# Save ref of sampled_token_ids CPU tensor if the batch contains
- # any requests with sampling params that that require output ids.
+ # any requests with sampling params that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
@@ -2890,6 +3073,37 @@ class GPUModelRunner(
self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids)
+ def _copy_valid_sampled_token_count(
+ self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
+ ) -> None:
+ if self.valid_sampled_token_count_event is None:
+ return
+
+ default_stream = torch.cuda.current_stream()
+ # Initialize a new stream to overlap the copy operation with
+ # prepare_input of draft model.
+ with torch.cuda.stream(self.valid_sampled_token_count_copy_stream):
+ self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore
+ counts = valid_sampled_tokens_count
+ counts_cpu = self.valid_sampled_token_count_cpu
+ counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
+ self.valid_sampled_token_count_event.record()
+
+ self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
+
+ def _get_valid_sampled_token_count(self) -> list[int]:
+ # Wait until valid_sampled_tokens_count is copied to cpu,
+ prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
+ if (
+ self.valid_sampled_token_count_event is None
+ or prev_sampled_token_ids is None
+ ):
+ return []
+
+ counts_cpu = self.valid_sampled_token_count_cpu
+ self.valid_sampled_token_count_event.synchronize()
+ return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
+
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
@@ -2977,6 +3191,9 @@ class GPUModelRunner(
self.num_discarded_requests,
)
)
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count
+ )
if spec_decode_metadata is None:
token_indices_to_sample = None
@@ -3542,7 +3759,7 @@ class GPUModelRunner(
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
- seq_lens = max_query_len
+ seq_lens = max_query_len # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
@@ -4342,6 +4559,22 @@ class GPUModelRunner(
"and make sure compilation mode is VLLM_COMPILE"
)
+ # if we have dedicated decode cudagraphs, and spec-decode is enabled,
+ # we need to adjust the cudagraph sizes to be a multiple of the uniform
+ # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
+ # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
+ # Will be removed in the near future when we have separate cudagraph capture
+ # sizes for decode and mixed prefill-decode.
+ if (
+ cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
+ and cudagraph_mode.separate_routine()
+ and self.uniform_decode_query_len > 1
+ ):
+ self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
+ self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
+ )
+ self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
+
# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
self.cudagraph_dispatcher.initialize_cudagraph_keys(
@@ -4479,11 +4712,7 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
- num_speculative_tokens=(
- self.vllm_config.speculative_config.num_speculative_tokens
- if self.vllm_config.speculative_config
- else 0
- ),
+ num_speculative_tokens=self.num_spec_tokens,
)
def _allocate_kv_cache_tensors(
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 283e3744bcf6f..315f01b68499a 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
-import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
@@ -45,7 +44,6 @@ from vllm.v1.core.sched.output import GrammarOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
- EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
ModelRunnerOutput,
@@ -189,6 +187,7 @@ class Worker(WorkerBase):
and self.parallel_config.distributed_executor_backend
not in ["ray", "external_launcher"]
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
+ and self.vllm_config.parallel_config.nnodes_within_dp == 1
):
# Use local DP rank if available, otherwise use global DP rank.
dp_local_rank = self.parallel_config.data_parallel_rank_local
@@ -205,7 +204,14 @@ class Worker(WorkerBase):
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
-
+ visible_device_count = (
+ torch.cuda.device_count() if torch.cuda.is_available() else 0
+ )
+ assert self.parallel_config.local_world_size <= visible_device_count, (
+ f"local_world_size ({self.parallel_config.local_world_size}) must be "
+ f"less than or equal to the number of visible devices "
+ f"({visible_device_count})."
+ )
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)
@@ -573,18 +579,7 @@ class Worker(WorkerBase):
all_gather_tensors=all_gather_tensors,
)
- kv_connector_output = output.kv_connector_output
- if not kv_connector_output:
- return None
-
- # In case of PP with kv transfer, we need to pass through the
- # kv_connector_output
- if kv_connector_output.is_empty():
- return EMPTY_MODEL_RUNNER_OUTPUT
-
- output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
- output.kv_connector_output = kv_connector_output
- return output
+ return None
def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.model_runner.take_draft_token_ids()
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 01490e0dfac9c..e9eb7cad38f88 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1254,13 +1254,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1:
- valid_sampled_token_ids = selected_token_ids.tolist()
+ valid_sampled_token_ids: list[np.ndarray] = [
+ row for row in selected_token_ids.numpy()
+ ]
# Mask out the sampled tokens that should not be sampled.
# TODO: Keep in sync with gpu_model_runner.py, in particular
# the "else" case here
for i in discard_sampled_tokens_req_indices:
- valid_sampled_token_ids[i].clear()
+ valid_sampled_token_ids[i] = np.array([])
# Append sampled tokens
for i, req_state, seq_len in request_seq_lens:
@@ -1273,7 +1275,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
valid_mask = selected_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
- seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
+ seq.numpy() for seq in selected_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py
index 9f16b1e6d03ee..be8326e2fdbc1 100644
--- a/vllm/v1/worker/ubatching.py
+++ b/vllm/v1/worker/ubatching.py
@@ -27,8 +27,8 @@ class UBatchContext:
ready_barrier: threading.Barrier,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
- gpu_comm_done_event: torch.cuda.Event,
- gpu_compute_done_event: torch.cuda.Event,
+ gpu_comm_done_event: torch.Event,
+ gpu_compute_done_event: torch.Event,
schedule: str = "default",
):
self.id = id
@@ -207,8 +207,8 @@ def make_ubatch_contexts(
Create a context manager for micro-batching synchronization.
"""
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
- gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
- gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
+ gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
+ gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
assert len(forward_contexts) == 2
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index 3991c16eefba9..16f321c080779 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -180,6 +180,7 @@ class WorkerWrapperBase:
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
+ global_rank: int | None = None,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
@@ -192,6 +193,7 @@ class WorkerWrapperBase:
group.
"""
self.rpc_rank = rpc_rank
+ self.global_rank = self.rpc_rank if global_rank is None else global_rank
self.worker: WorkerBase | None = None
# do not store this `vllm_config`, `init_worker` will set the final
@@ -312,7 +314,7 @@ class WorkerWrapperBase:
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
- kv_cache_config = kv_cache_configs[self.rpc_rank]
+ kv_cache_config = kv_cache_configs[self.global_rank]
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py
index 4f82c18da73aa..30563305853a5 100644
--- a/vllm/v1/worker/xpu_model_runner.py
+++ b/vllm/v1/worker/xpu_model_runner.py
@@ -37,19 +37,12 @@ class XPUModelRunner(GPUModelRunner):
@contextmanager
def _torch_cuda_wrapper():
- class _EventPlaceholder:
- def __init__(self, *args, **kwargs) -> None:
- self.record = lambda: None
- self.synchronize = lambda: None
-
try:
# replace cuda APIs with xpu APIs, this should work by default
- torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream
yield
finally:
- # if anything goes wrong, just patch it with a placeholder
- torch.cuda.Event = _EventPlaceholder
+ pass