diff --git a/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh new file mode 100644 index 0000000000000..937a43d1a3221 --- /dev/null +++ b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +set -euxo pipefail + +# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] +THRESHOLD=${1:-0.25} +NUM_Q=${2:-1319} +PORT=${3:-8040} +OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled} +mkdir -p "${OUT_DIR}" + +wait_for_server() { + local port=$1 + timeout 600 bash -c ' + until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do + sleep 1 + done' +} + +MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct" + +# Set BACKENDS based on platform +if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then + # ROCm platform + BACKENDS=("allgather_reducescatter") + # Disable MOE padding for ROCm since it is causing eplb to fail + export VLLM_ROCM_MOE_PADDING=0 +else + # Non-ROCm platform (CUDA/other) + BACKENDS=("deepep_high_throughput" "deepep_low_latency") +fi + +cleanup() { + if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then + kill "${SERVER_PID}" 2>/dev/null || true + for _ in {1..20}; do + kill -0 "${SERVER_PID}" 2>/dev/null || break + sleep 0.5 + done + kill -9 "${SERVER_PID}" 2>/dev/null || true + fi +} +trap cleanup EXIT + +for BACK in "${BACKENDS[@]}"; do + VLLM_DEEP_GEMM_WARMUP=skip \ + VLLM_ALL2ALL_BACKEND=$BACK \ + vllm serve "$MODEL" \ + --enforce-eager \ + --tensor-parallel-size 4 \ + --enable-expert-parallel \ + --enable-eplb \ + --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \ + --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \ + --trust-remote-code \ + --max-model-len 2048 \ + --gpu-memory-utilization 0.9 \ + --port $PORT & + SERVER_PID=$! 
+ wait_for_server $PORT + + TAG=$(echo "$MODEL" | tr '/: \\n' '_____') + OUT="${OUT_DIR}/${TAG}_${BACK}.json" + python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT} + python3 - <<PY +import json +acc = json.load(open("${OUT}"))["accuracy"] +assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}" +PY + + cleanup + SERVER_PID= + sleep 1 + PORT=$((PORT+1)) +done diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index c7d460be6e2b5..3c9b8cbedcf06 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -61,8 +61,8 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min - timeout_in_minutes: 20 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 grade: Blocking @@ -73,6 +73,7 @@ steps: - tests/multimodal - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ + - tests/tool_parsers - tests/transformers_utils - tests/config no_gpu: true @@ -82,6 +83,7 @@ steps: - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers - pytest -v -s transformers_utils - pytest -v -s config @@ -759,19 +761,7 @@ steps: - vllm/ - tests/tool_use commands: - - pytest -v -s -m 'not cpu_test' tool_use - -- label: OpenAI-Compatible Tool Use (CPU) # 5 mins - mirror_hardwares: [amdexperimental, amdproduction] - agent_pool: mi325_1 - # grade: Blocking - timeout_in_minutes: 10 - source_file_dependencies: - - vllm/ - - tests/tool_use - no_gpu: true - commands: - - pytest -v -s -m 'cpu_test' tool_use + - pytest -v -s tool_use ##### models test ##### @@ -1629,7 +1619,6 @@ steps: mirror_hardwares: [amdexperimental] agent_pool: mi325_4 # grade: Blocking - gpu: h100 optional: true num_gpus: 4 working_dir: "/vllm-workspace" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0a5b56f473c29..8e6d32f71f220 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -57,8 +57,8 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min - timeout_in_minutes: 20 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min + timeout_in_minutes: 30 source_file_dependencies: - vllm/ - tests/test_inputs.py @@ -66,6 +66,7 @@ - tests/multimodal - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ + - tests/tool_parsers - tests/transformers_utils - tests/config no_gpu: true @@ -75,6 +76,7 @@ steps: - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers - pytest -v -s transformers_utils - pytest -v -s config @@ -652,7 +654,7 @@ steps: - vllm/model_executor/layers/quantization autorun_on_main: true commands: - - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt - label: OpenAI API correctness # 22min timeout_in_minutes: 30 @@ -672,16 +674,7 @@ steps: - vllm/ - tests/tool_use commands: - - pytest -v -s -m 'not cpu_test' tool_use - -- label: OpenAI-Compatible Tool Use (CPU) # 5 mins - timeout_in_minutes: 10 - source_file_dependencies: - - vllm/ - - tests/tool_use - no_gpu: true - commands: - - pytest -v -s -m 'cpu_test' tool_use + - 
pytest -v -s tool_use ##### models test ##### @@ -692,6 +685,7 @@ steps: source_file_dependencies: - vllm/ - tests/models/test_initialization.py + - tests/models/registry.py commands: # Run a subset of model initialization tests - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset @@ -704,6 +698,7 @@ steps: - vllm/model_executor/models/ - vllm/transformers_utils/ - tests/models/test_initialization.py + - tests/models/registry.py commands: # Only when vLLM model source is modified - test initialization of a large # subset of supported models (the complement of the small subset in the above @@ -1069,7 +1064,7 @@ steps: - csrc/ - vllm/model_executor/layers/quantization commands: - - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt ##### 1 GPU test ##### ##### multi gpus test ##### @@ -1228,6 +1223,8 @@ steps: # FIXIT: find out which code initialize cuda before running the test # before the fix, we need to use spawn to test it - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # Alot of these tests are on the edge of OOMing + - export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # There is some Tensor Parallelism related processing logic in LoRA that # requires multi-GPU testing for validation. - pytest -v -s -x lora/test_chatglm3_tp.py @@ -1346,6 +1343,7 @@ steps: - label: Prime-RL Integration Test # 15min timeout_in_minutes: 30 optional: true + soft_fail: true num_gpus: 2 working_dir: "/vllm-workspace" source_file_dependencies: @@ -1379,4 +1377,4 @@ steps: num_gpus: 2 working_dir: "/vllm-workspace" commands: - - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 \ No newline at end of file + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 diff --git a/.buildkite/test_areas/misc.yaml b/.buildkite/test_areas/misc.yaml index 072bccadb726a..252af1e56a105 100644 --- a/.buildkite/test_areas/misc.yaml +++ b/.buildkite/test_areas/misc.yaml @@ -115,7 +115,7 @@ steps: - label: Async Engine, Inputs, Utils, Worker, Config (CPU) depends_on: ~ - timeout_in_minutes: 20 + timeout_in_minutes: 30 source_file_dependencies: - vllm/ - tests/test_inputs.py @@ -123,6 +123,7 @@ steps: - tests/multimodal - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ + - tests/tool_parsers - tests/transformers_utils - tests/config no_gpu: true @@ -132,6 +133,7 @@ steps: - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers - pytest -v -s transformers_utils - pytest -v -s config diff --git a/.buildkite/test_areas/tool_use.yaml b/.buildkite/test_areas/tool_use.yaml index 7040cd1d253b3..69527a1214229 100644 --- a/.buildkite/test_areas/tool_use.yaml +++ b/.buildkite/test_areas/tool_use.yaml @@ -10,14 +10,4 @@ steps: - vllm/ - tests/tool_use commands: - - pytest -v -s -m 'not cpu_test' tool_use - -- label: OpenAI-Compatible Tool Use (CPU) - depends_on: ~ - timeout_in_minutes: 10 - source_file_dependencies: - - vllm/ - - tests/tool_use - no_gpu: true - commands: - - pytest -v -s -m 'cpu_test' tool_use + - pytest -v -s tool_use diff --git a/README.md b/README.md index 5c040fe4a66d2..26222b815370d 100644 --- a/README.md +++ b/README.md @@ -143,11 +143,13 @@ Compute Resources: - Databricks - DeepInfra - Google Cloud +- IBM - Intel - Lambda Lab - Nebius - Novita AI - 
NVIDIA +- Red Hat - Replicate - Roblox - RunPod diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index 25baa9cbda39c..a245e2022e605 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -18,6 +18,11 @@ MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0} MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000} NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"} NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"} +HOSTNAME=$(hostname) +if [[ -z "$HOSTNAME" ]]; then + echo "Error: Failed to determine hostname." >&2 + exit 1 +fi LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" @@ -82,6 +87,7 @@ start_server() { "$MODEL" "--disable-log-requests" "--port" "8004" + "--host" "$HOSTNAME" "--gpu-memory-utilization" "$gpu_memory_utilization" "--max-num-seqs" "$max_num_seqs" "--max-num-batched-tokens" "$max_num_batched_tokens" @@ -113,7 +119,7 @@ start_server() { # since that we should always have permission to send signal to the server process. kill -0 $server_pid 2> /dev/null || break - RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout) + RESPONSE=$(curl -s -X GET "http://${HOSTNAME}:8004/health" -w "%{http_code}" -o /dev/stdout) STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) if [[ "$STATUS_CODE" -eq 200 ]]; then server_started=1 @@ -173,6 +179,7 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 1000 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') @@ -188,7 +195,7 @@ run_benchmark() { request_rate=$((${throughput%.*} + 1)) while ((request_rate > 0)); do # clear prefix cache - curl -X POST http://0.0.0.0:8004/reset_prefix_cache + curl -X POST http://${HOSTNAME}:8004/reset_prefix_cache sleep 5 bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt" vllm bench serve \ @@ -204,6 +211,7 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 100 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') @@ -304,6 +312,7 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 100 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 \ --profile &> "$bm_log" else diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index d69d74ca61f54..831b76b66e096 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -620,7 +620,7 @@ def get_tokenizer( kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: - from vllm.tokenizers import MistralTokenizer + from vllm.tokenizers.mistral import MistralTokenizer except ImportError as e: raise ImportError( "MistralTokenizer requires vllm package.\n" diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 2cf3c1a755d3c..0d4f9b7aa07c8 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -35,16 +35,21 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") # sm90a set(SUPPORT_ARCHS) -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3) - list(APPEND SUPPORT_ARCHS 9.0a) 
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3) + list(APPEND SUPPORT_ARCHS "9.0a") endif() -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8) - list(APPEND SUPPORT_ARCHS 10.0a) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9) + # CUDA 12.9 has introduced "Family-Specific Architecture Features" + # this supports all compute_10x family + list(APPEND SUPPORT_ARCHS "10.0f") +elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + list(APPEND SUPPORT_ARCHS "10.0a") endif() cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}") if(FLASH_MLA_ARCHS) + message(STATUS "FlashMLA CUDA architectures: ${FLASH_MLA_ARCHS}") set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS}) list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math") @@ -126,7 +131,8 @@ if(FLASH_MLA_ARCHS) $<$:-UPy_LIMITED_API> $<$:-UPy_LIMITED_API>) else() - # Create empty targets for setup.py when not targeting sm90a systems + message(STATUS "FlashMLA will not compile: unsupported CUDA architecture ${CUDA_ARCHS}") + # Create empty targets for setup.py on unsupported systems add_custom_target(_flashmla_C) add_custom_target(_flashmla_extension_C) endif() diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 5fa367abd96f5..7229e420d3fe4 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) { template __device__ inline T apply_scoring(T val) { - if constexpr (SF == SCORING_SIGMOID) { + if constexpr (SF == SCORING_NONE) { + return val; + } else if constexpr (SF == SCORING_SIGMOID) { return apply_sigmoid(val); } else { + static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID, + "Unsupported ScoringFunc in apply_scoring"); return val; } } @@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { if (if_proceed_next_topk) { + float scale = routed_scaling_factor; + if (renormalize) { + scale /= topk_sum; + } for (int i = lane_id; i < topk; i += WARP_SIZE) { float base = cuda_cast(s_topk_value[i]); - float value = renormalize ? 
(base / topk_sum * routed_scaling_factor) - : (base * routed_scaling_factor); + float value = base * scale; topk_indices[i] = s_topk_idx[i]; topk_values[i] = value; } diff --git a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu index f9ac874c43730..49d1b2086b8db 100644 --- a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu @@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) { return val; } +template +__device__ __forceinline__ float ComputeGroupScale( + const T* __restrict__ group_input, T* __restrict__ smem_group, + const int group_size, const int lane_id, const int threads_per_group, + const float eps, const float max_8bit) { + float local_absmax = eps; + + constexpr int vec_size = 16 / sizeof(T); + + // copy global -> shared & compute absmax + auto scalar_op_cache = [&] __device__(T & dst, const T& src) { + float abs_v = fabsf(static_cast(src)); + local_absmax = fmaxf(local_absmax, abs_v); + dst = src; + }; + + vllm::vectorize_with_alignment( + group_input, // in + smem_group, // out (shared) + group_size, // elements per group + lane_id, // thread id + threads_per_group, // stride in group + scalar_op_cache); // scalar handler + + local_absmax = GroupReduceMax(local_absmax); + + float y_s = local_absmax / max_8bit; + if constexpr (SCALE_UE8M0) { + y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); + } + + return y_s; +} + +template +__device__ __forceinline__ void QuantizeGroup( + const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output, + const int group_size, const int lane_id, const int threads_per_group, + const float y_s, const float min_8bit, const float max_8bit) { + constexpr int vec_size = 16 / sizeof(T); + + // quantize shared -> global 8-bit + auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { + float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); + dst = DST_DTYPE(q); + }; + + vllm::vectorize_with_alignment( + smem_group, // in (shared) + group_output, // out (global quant tensor) + group_size, // elements + lane_id, // tid + threads_per_group, // stride + scalar_op_quant); // scalar handler +} + template __global__ void per_token_group_quant_8bit_kernel( @@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel( const int64_t global_group_id = block_group_id + local_group_id; const int64_t block_group_offset = global_group_id * group_size; - float local_absmax = eps; - using scale_element_t = float; static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); @@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel( T* smem = reinterpret_cast(smem_raw); T* smem_group = smem + local_group_id * group_size; - constexpr int vec_size = 16 / sizeof(T); - using vec_t = vllm::vec_n_t; - - // copy global -> shared & compute absmax - auto scalar_op_cache = [&] __device__(T & dst, const T& src) { - float abs_v = fabsf(static_cast(src)); - local_absmax = fmaxf(local_absmax, abs_v); - dst = src; - }; - - vllm::vectorize_with_alignment( - group_input, // in - smem_group, // out (shared) - group_size, // elements per group - lane_id, // thread id - threads_per_group, // stride in group - scalar_op_cache); // scalar handler - - local_absmax = GroupReduceMax(local_absmax); - - float y_s = local_absmax / max_8bit; - if constexpr (SCALE_UE8M0) { - y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); - } + const float y_s = ComputeGroupScale( + group_input, 
smem_group, group_size, lane_id, threads_per_group, eps, + max_8bit); scale_element_t y_s_quant = y_s; @@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel( __syncthreads(); - // quantize shared -> global 8-bit - auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { - float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); - dst = DST_DTYPE(q); - }; + QuantizeGroup(smem_group, group_output, group_size, lane_id, + threads_per_group, y_s, min_8bit, max_8bit); +} - vllm::vectorize_with_alignment( - smem_group, // in (shared) - group_output, // out (global quant tensor) - group_size, // elements - lane_id, // tid - threads_per_group, // stride - scalar_op_quant); // scalar handler +inline int GetGroupsPerBlock(int64_t num_groups) { + if (num_groups % 16 == 0) { + return 16; + } + if (num_groups % 8 == 0) { + return 8; + } + if (num_groups % 4 == 0) { + return 4; + } + if (num_groups % 2 == 0) { + return 2; + } + return 1; } void per_token_group_quant_8bit(const torch::Tensor& input, @@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input, constexpr int THREADS_PER_GROUP = 16; - int groups_per_block = 1; - - if (num_groups % 16 == 0) { - groups_per_block = 16; - } else if (num_groups % 8 == 0) { - groups_per_block = 8; - } else if (num_groups % 4 == 0) { - groups_per_block = 4; - } else if (num_groups % 2 == 0) { - groups_per_block = 2; - } + const int groups_per_block = GetGroupsPerBlock(num_groups); auto dst_type = output_q.scalar_type(); const int num_blocks = num_groups / groups_per_block; @@ -225,8 +253,6 @@ __global__ void per_token_group_quant_8bit_packed_kernel( const int64_t block_group_offset = global_group_id * group_size; - float local_absmax = eps; - const T* group_input = input + block_group_offset; DST_DTYPE* group_output = static_cast(output_q) + block_group_offset; @@ -235,29 +261,9 @@ __global__ void per_token_group_quant_8bit_packed_kernel( extern __shared__ __align__(16) char smem_raw[]; T* smem = reinterpret_cast(smem_raw); T* smem_group = smem + local_group_id * group_size; - - constexpr int vec_size = 16 / sizeof(T); - using vec_t = vllm::vec_n_t; - - // copy global -> shared & compute absmax - auto scalar_op_cache = [&] __device__(T & dst, const T& src) { - float abs_v = fabsf(static_cast(src)); - local_absmax = fmaxf(local_absmax, abs_v); - dst = src; - }; - - vllm::vectorize_with_alignment( - group_input, // in - smem_group, // out (shared) - group_size, // elements per group - lane_id, // thread id - threads_per_group, // stride in group - scalar_op_cache); // scalar handler - - local_absmax = GroupReduceMax(local_absmax); - - float y_s = local_absmax / max_8bit; - y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); + const float y_s = + ComputeGroupScale(group_input, smem_group, group_size, lane_id, + threads_per_group, eps, max_8bit); // pack 4 scales into a uint32 if (lane_id == 0) { @@ -284,19 +290,8 @@ __global__ void per_token_group_quant_8bit_packed_kernel( __syncthreads(); - // quantize shared -> global 8-bit - auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { - float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); - dst = DST_DTYPE(q); - }; - - vllm::vectorize_with_alignment( - smem_group, // in (shared) - group_output, // out (global quant tensor) - group_size, // elements - lane_id, // tid - threads_per_group, // stride - scalar_op_quant); // scalar handler + QuantizeGroup(smem_group, group_output, group_size, lane_id, + threads_per_group, y_s, 
min_8bit, max_8bit); } void per_token_group_quant_8bit_packed(const torch::Tensor& input, @@ -337,17 +332,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input, constexpr int THREADS_PER_GROUP = 16; - int groups_per_block = 1; - - if (num_groups % 16 == 0) { - groups_per_block = 16; - } else if (num_groups % 8 == 0) { - groups_per_block = 8; - } else if (num_groups % 4 == 0) { - groups_per_block = 4; - } else if (num_groups % 2 == 0) { - groups_per_block = 2; - } + const int groups_per_block = GetGroupsPerBlock(num_groups); auto dst_type = output_q.scalar_type(); const int num_blocks = num_groups / groups_per_block; diff --git a/docker/Dockerfile b/docker/Dockerfile index 0d50d97e54c6c..ae2624ace67b9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -32,7 +32,7 @@ ARG DEADSNAKES_GPGKEY_URL # The PyPA get-pip.py script is a self contained script+zip file, that provides # both the installer script and the pip base85-encoded zip archive. This allows -# bootstrapping pip in environment where a dsitribution package does not exist. +# bootstrapping pip in environment where a distribution package does not exist. # # By parameterizing the URL for get-pip.py installation script, we allow # third-party to use their own copy of the script stored in a private mirror. @@ -73,15 +73,13 @@ ARG INSTALL_KV_CONNECTORS=false #################### BASE BUILD IMAGE #################### # prepare basic build environment FROM ${BUILD_BASE_IMAGE} AS base + ARG CUDA_VERSION ARG PYTHON_VERSION -ARG TARGETPLATFORM -ARG INSTALL_KV_CONNECTORS=false + ENV DEBIAN_FRONTEND=noninteractive -ARG GET_PIP_URL - -# Install system dependencies and uv, then create Python virtual environment +# Install system dependencies including build tools RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ @@ -107,32 +105,30 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && ln -s /opt/venv/bin/pip /usr/bin/pip \ && python3 --version && python3 -m pip --version -ARG PIP_INDEX_URL UV_INDEX_URL -ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL -ARG PYTORCH_CUDA_INDEX_BASE_URL -ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER - # Activate virtual environment and add uv to PATH ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH" ENV VIRTUAL_ENV="/opt/venv" -# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out -# Reference: https://github.com/astral-sh/uv/pull/1694 +# Environment for uv ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" -# Use copy mode to avoid hardlink failures with Docker cache mounts ENV UV_LINK_MODE=copy -RUN <> /etc/environment -# Install Python and other dependencies +# Install Python and system dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ @@ -408,63 +421,104 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version -# Install CUDA development tools and build essentials for runtime JIT compilation +# Install CUDA development tools for runtime JIT compilation # (FlashInfer, DeepGEMM, EP kernels all require compilation at runtime) RUN CUDA_VERSION_DASH=$(echo $CUDA_VERSION | cut -d. -f1,2 | tr '.' 
'-') && \ apt-get update -y && \ apt-get install -y --no-install-recommends \ - cuda-nvcc-${CUDA_VERSION_DASH} \ - cuda-cudart-${CUDA_VERSION_DASH} \ - cuda-nvrtc-${CUDA_VERSION_DASH} \ - cuda-cuobjdump-${CUDA_VERSION_DASH} \ - # https://github.com/vllm-project/vllm/issues/29590 - libcurand-dev-${CUDA_VERSION_DASH} \ - libcublas-${CUDA_VERSION_DASH} \ - # Fixes nccl_allocator requiring nccl.h at runtime - # https://github.com/vllm-project/vllm/blob/1336a1ea244fa8bfd7e72751cabbdb5b68a0c11a/vllm/distributed/device_communicators/pynccl_allocator.py#L22 - libnccl-dev && \ + cuda-nvcc-${CUDA_VERSION_DASH} \ + cuda-cudart-${CUDA_VERSION_DASH} \ + cuda-nvrtc-${CUDA_VERSION_DASH} \ + cuda-cuobjdump-${CUDA_VERSION_DASH} \ + libcurand-dev-${CUDA_VERSION_DASH} \ + libcublas-${CUDA_VERSION_DASH} \ + # Fixes nccl_allocator requiring nccl.h at runtime + # https://github.com/vllm-project/vllm/blob/1336a1ea244fa8bfd7e72751cabbdb5b68a0c11a/vllm/distributed/device_communicators/pynccl_allocator.py#L22 + libnccl-dev && \ rm -rf /var/lib/apt/lists/* +# Install uv for faster pip installs +RUN python3 -m pip install uv + +# Environment for uv +ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" +ENV UV_LINK_MODE=copy + +# Workaround for triton/pytorch issues +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ + +# ============================================================ +# SLOW-CHANGING DEPENDENCIES BELOW +# These are the expensive layers that we want to cache +# ============================================================ + +# Install PyTorch and core CUDA dependencies +# This is ~2GB and rarely changes +ARG PYTORCH_CUDA_INDEX_BASE_URL +COPY requirements/common.txt /tmp/common.txt +COPY requirements/cuda.txt /tmp/requirements-cuda.txt +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r /tmp/requirements-cuda.txt \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \ + rm /tmp/requirements-cuda.txt /tmp/common.txt + +# Install FlashInfer pre-compiled kernel cache and binaries +# This is ~1.1GB and only changes when FlashInfer version bumps +# https://docs.flashinfer.ai/installation.html +ARG FLASHINFER_VERSION=0.5.3 +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \ + && uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \ + --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') \ + && flashinfer show-config + +# ============================================================ +# OPENAI API SERVER DEPENDENCIES +# Pre-install these to avoid reinstalling on every vLLM wheel rebuild +# ============================================================ + +# Install gdrcopy (saves ~6s per build) +# TODO (huydhn): There is no prebuilt gdrcopy package on 12.9 at the moment +ARG GDRCOPY_CUDA_VERSION=12.8 +ARG GDRCOPY_OS_VERSION=Ubuntu22_04 +ARG TARGETPLATFORM +COPY tools/install_gdrcopy.sh /tmp/install_gdrcopy.sh +RUN set -eux; \ + case "${TARGETPLATFORM}" in \ + linux/arm64) UUARCH="aarch64" ;; \ + linux/amd64) UUARCH="x64" ;; \ + *) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \ + esac; \ + /tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}" && \ + rm /tmp/install_gdrcopy.sh + +# Install vllm-openai dependencies (saves ~2.6s per build) +# These are stable packages that don't depend on vLLM itself +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + BITSANDBYTES_VERSION="0.42.0"; \ + else \ + BITSANDBYTES_VERSION="0.46.1"; \ + fi; \ + uv pip install --system accelerate hf_transfer modelscope \ + "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3' + +# ============================================================ +# VLLM INSTALLATION (depends on build stage) +# ============================================================ + ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL ARG PYTORCH_CUDA_INDEX_BASE_URL ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER -# Install uv for faster pip installs -RUN --mount=type=cache,target=/root/.cache/uv \ - python3 -m pip install uv - -# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out -# Reference: https://github.com/astral-sh/uv/pull/1694 -ENV UV_HTTP_TIMEOUT=500 -ENV UV_INDEX_STRATEGY="unsafe-best-match" -# Use copy mode to avoid hardlink failures with Docker cache mounts -ENV UV_LINK_MODE=copy - -# Workaround for https://github.com/openai/triton/issues/2507 and -# https://github.com/pytorch/pytorch/issues/107960 -- hopefully -# this won't be needed for future versions of this docker image -# or future versions of triton. -RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ - # Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/uv \ uv pip install --system dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') -# Install FlashInfer pre-compiled kernel cache and binaries -# https://docs.flashinfer.ai/installation.html -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system flashinfer-cubin==0.5.3 \ - && uv pip install --system flashinfer-jit-cache==0.5.3 \ - --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ - && flashinfer show-config - -COPY examples examples -COPY benchmarks benchmarks -COPY ./vllm/collect_env.py . - RUN --mount=type=cache,target=/root/.cache/uv \ . 
/etc/environment && \ uv pip list @@ -478,7 +532,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ echo "No DeepGEMM wheels to install; skipping."; \ fi' -# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH (https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/.ci/manywheel/build_cuda.sh#L141C14-L141C36) +# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH # Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage @@ -487,23 +541,17 @@ RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm uv pip install --system ep_kernels/dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') -RUN --mount=type=bind,source=tools/install_gdrcopy.sh,target=/tmp/install_gdrcopy.sh,ro \ - set -eux; \ - case "${TARGETPLATFORM}" in \ - linux/arm64) UUARCH="aarch64" ;; \ - linux/amd64) UUARCH="x64" ;; \ - *) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \ - esac; \ - /tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}" - # CUDA image changed from /usr/local/nvidia to /usr/local/cuda in 12.8 but will # return to /usr/local/nvidia in 13.0 to allow container providers to mount drivers # consistently from the host (see https://github.com/vllm-project/vllm/issues/18859). # Until then, add /usr/local/nvidia/lib64 before the image cuda path to allow override. ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} +# Copy examples and benchmarks at the end to minimize cache invalidation +COPY examples examples +COPY benchmarks benchmarks +COPY ./vllm/collect_env.py . #################### vLLM installation IMAGE #################### - #################### TEST IMAGE #################### # image to run unit testing suite # note that this uses vllm installed by `pip` @@ -569,18 +617,12 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 -# install additional dependencies for openai api server +# install kv_connectors if requested RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \ if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \ uv pip install --system -r /tmp/kv_connectors.txt; \ - fi; \ - if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - BITSANDBYTES_VERSION="0.42.0"; \ - else \ - BITSANDBYTES_VERSION="0.46.1"; \ - fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3' + fi ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index adac43c6accbe..72d2053102c22 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -76,6 +76,9 @@ RUN python3 -m pip install -e tests/vllm_test_utils ENV NIXL_VERSION=0.7.0 RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py +# PyJWT-2.7.0 will influence some wheel behaviors, remove its dist-info to avoid conflicts +RUN rm /usr/lib/python3/dist-packages/PyJWT-2.7.0.dist-info/ -rf + # remove torch bundled oneccl to avoid conflicts RUN --mount=type=cache,target=/root/.cache/pip \ pip uninstall oneccl oneccl-devel -y diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png index 
7420ca4d89441..c8839eb93de95 100644 Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ diff --git a/docs/community/sponsors.md b/docs/community/sponsors.md index fd1c82376d086..847b99cce45c9 100644 --- a/docs/community/sponsors.md +++ b/docs/community/sponsors.md @@ -24,11 +24,13 @@ Compute Resources: - Databricks - DeepInfra - Google Cloud +- IBM - Intel - Lambda Lab - Nebius - Novita AI - NVIDIA +- Red Hat - Replicate - Roblox - RunPod diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index fdd9c317b022f..556d9f8b9420a 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -7,7 +7,7 @@ This guide covers optimization strategies and performance tuning for vLLM V1. ## Preemption -Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. +Due to the autoregressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. In such cases, vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes available again. When this occurs, you may see the following warning: diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md index 0e636c87f38a4..d70e0142e3202 100644 --- a/docs/deployment/docker.md +++ b/docs/deployment/docker.md @@ -82,7 +82,7 @@ DOCKER_BUILDKIT=1 docker build . \ ## Building for Arm64/aarch64 -A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64. +A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper and Grace-Blackwell. Using the flag `--platform "linux/arm64"` will build for arm64. !!! note Multiple modules must be compiled, so this process can take a while. Recommend using `--build-arg max_jobs=` & `--build-arg nvcc_threads=` @@ -104,6 +104,25 @@ A docker container can be built for aarch64 systems such as the Nvidia Grace-Hop --build-arg RUN_WHEEL_CHECK=false ``` +For (G)B300, we recommend using CUDA 13, as shown in the following command. + +??? console "Command" + + ```bash + DOCKER_BUILDKIT=1 docker build \ + --build-arg CUDA_VERSION=13.0.1 \ + --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 \ + --build-arg max_jobs=256 \ + --build-arg nvcc_threads=2 \ + --build-arg RUN_WHEEL_CHECK=false \ + --build-arg torch_cuda_arch_list='9.0 10.0+PTX' \ + --platform "linux/arm64" \ + --tag vllm/vllm-gb300-openai:latest \ + --target vllm-openai \ + -f docker/Dockerfile \ + . + ``` + !!! note If you are building the `linux/arm64` image on a non-ARM host (e.g., an x86_64 machine), you need to ensure your system is set up for cross-compilation using QEMU. This allows your host machine to emulate ARM64 execution. diff --git a/docs/deployment/integrations/production-stack.md b/docs/deployment/integrations/production-stack.md index 2f1894ccf0022..624e98a08c98d 100644 --- a/docs/deployment/integrations/production-stack.md +++ b/docs/deployment/integrations/production-stack.md @@ -4,7 +4,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le * **Upstream vLLM compatibility** – It wraps around upstream vLLM without modifying its code. 
* **Ease of use** – Simplified deployment via Helm charts and observability through Grafana dashboards. -* **High performance** – Optimized for LLM workloads with features like multi-model support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others. +* **High performance** – Optimized for LLM workloads with features like multimodel support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others. If you are new to Kubernetes, don't worry: in the vLLM production stack [repo](https://github.com/vllm-project/production-stack), we provide a step-by-step [guide](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) and a [short video](https://www.youtube.com/watch?v=EsTJbQtzj0g) to set up everything and get started in **4 minutes**! diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index 7baadf8ba23cb..19c02fc88641c 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -41,7 +41,7 @@ These features allow the most flexibility for cudagraph capture and compilation * `NONE` — turn CUDA Graphs off. Good for debugging. * `PIECEWISE` — a single-mode strategy (and past default). It is the most flexible: attention or other CUDA Graphs-incompatible operations stay eager, everything else goes into CUDA Graphs. Requires piecewise compilation. * `FULL` — a single-mode strategy, which only captures full CUDA Graphs for non-uniform batches, then uniform-decode batches reuse the CUDA Graph of non-uniform batch of the same batch_size, since they are compatible; can be good for small models or workloads with small prompts. -* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs. +* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc.; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs. * `FULL_AND_PIECEWISE` — (default mode) full CUDA Graph for uniform decode, piecewise CUDA Graphs for others; generally the most performant setting, especially for low latency with small models or MoEs, but also requires the most memory and takes the longest to capture. Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_PIECEWISE` for better performance, (for pooling models, it's still `PIECEWISE`). Otherwise, e.g. if piecewise compilation unavailable, we default to `NONE`. @@ -49,7 +49,7 @@ Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_ While `NONE` , `PIECEWISE`, and `FULL` are single-mode configurations and simply equivalent to past implementations of eager execution, piecewise CUDA Graphs, and full CUDA Graphs respectively, `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` are newly appended dual-mode configurations, which require dispatching to switch between concrete runtime modes according to runtime batches dynamically. !!! note - Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. 
If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potantial `NONE` if no suitable CUDA Graph available), depending on the batch composition. + Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potential `NONE` if no suitable CUDA Graph available), depending on the batch composition. While cascade attention is not cudagraph compatible, it is now compatible with all possible cudagraph mode configurations. If a batch uses cascade attention, it always gets dispatched to `PIECEWISE` mode if available (otherwise `NONE`). diff --git a/docs/design/optimization_levels.md b/docs/design/optimization_levels.md index 940286071ef3c..4987c1820ad32 100644 --- a/docs/design/optimization_levels.md +++ b/docs/design/optimization_levels.md @@ -4,7 +4,7 @@ ## Overview -vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechnaism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out of the box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten. +vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechanism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out-of-the-box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten. ## Level Summaries and Usage Examples ```bash diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index d87b2a639df12..5cc5878425515 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -36,7 +36,7 @@ the input pointers `q`, `k_cache`, and `v_cache`, which point to query, key, and value data on global memory that need to be read and processed. The output pointer `out` points to global memory where the result should be written. These four pointers actually -refer to multi-dimensional arrays, but each thread only accesses the +refer to multidimensional arrays, but each thread only accesses the portion of data assigned to it. I have omitted all other runtime parameters here for simplicity. @@ -229,7 +229,7 @@ manner. ## QK -As shown the pseudo code below, before the entire for loop block, we +As shown the pseudocode below, before the entire for loop block, we fetch the query data for one token and store it in `q_vecs`. Then, in the outer for loop, we iterate through different `k_ptrs` that point to different tokens and prepare the `k_vecs` in the inner for @@ -403,7 +403,7 @@ for ... { // Iteration over different blocks. } ``` -As shown in the above pseudo code, in the outer loop, similar to +As shown in the above pseudocode, in the outer loop, similar to `k_ptr`, `logits_vec` iterates over different blocks and reads `V_VEC_SIZE` elements from `logits`. 
In the inner loop, each thread reads `V_VEC_SIZE` elements from the same tokens as a diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index c77fe44659790..70a11d6def566 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -420,7 +420,7 @@ Flags: `--tool-call-parser pythonic --chat-template {see_above}` ## How to Write a Tool Parser Plugin -A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py](../../vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py). +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/tool_parsers/hermes_tool_parser.py](../../vllm/tool_parsers/hermes_tool_parser.py). Here is a summary of a plugin file: @@ -468,7 +468,7 @@ Here is a summary of a plugin file: # register the tool parser to ToolParserManager ToolParserManager.register_lazy_module( name="example", - module_path="vllm.entrypoints.openai.tool_parsers.example", + module_path="vllm.tool_parsers.example", class_name="ExampleToolParser", ) diff --git a/docs/getting_started/installation/cpu.arm.inc.md b/docs/getting_started/installation/cpu.arm.inc.md index ad9c7d9ef21be..657bf2509db01 100644 --- a/docs/getting_started/installation/cpu.arm.inc.md +++ b/docs/getting_started/installation/cpu.arm.inc.md @@ -16,15 +16,15 @@ vLLM offers basic model inferencing and serving on Arm CPU platform, with suppor # --8<-- [start:pre-built-wheels] Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels contain pre-compiled C++ binaries. -Please replace `` in the commands below with a specific version string (e.g., `0.11.2`). ```bash -uv pip install --pre vllm==+cpu --extra-index-url https://wheels.vllm.ai/%2Bcpu/ +export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//') +uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu ``` ??? console "pip" ```bash - pip install --pre vllm==+cpu --extra-index-url https://wheels.vllm.ai/%2Bcpu/ + pip install vllm==${VLLM_VERSION}+cpu --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu ``` The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. @@ -35,20 +35,28 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe * `https://wheels.vllm.ai/nightly/cpu/vllm` -To install from nightly index, copy the link address of the `*.whl` under this index to run, for example: - +To install from nightly index, run: ```bash -uv pip install -U https://wheels.vllm.ai/c756fb678184b867ed94e5613a529198f1aee423/vllm-0.13.0rc2.dev11%2Bgc756fb678.cpu-cp38-abi3-manylinux_2_31_aarch64.whl # current nightly build (the filename will change!) 
+uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu ``` +??? console "pip (there's a caveat)" + + Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). + + If you insist on using `pip`, you have to specify the full URL (link address) of the wheel file (which can be obtained from https://wheels.vllm.ai/nightly/cpu/vllm). + + ```bash + pip install https://wheels.vllm.ai/4fa7ce46f31cbd97b4651694caf9991cc395a259/vllm-0.13.0rc2.dev104%2Bg4fa7ce46f.cpu-cp38-abi3-manylinux_2_35_aarch64.whl # current nightly build (the filename will change!) + ``` + **Install specific revisions** -If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), specify the full commit hash in the index: -https://wheels.vllm.ai/${VLLM_COMMIT}/cpu/vllm . -Then, copy the link address of the `*.whl` under this index to run: +If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL: ```bash -uv pip install -U +export VLLM_COMMIT=730bd35378bf2a5b56b6d3a45be28b3092d26519 # use full commit hash from the main branch +uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}/cpu ``` # --8<-- [end:pre-built-wheels] @@ -103,10 +111,10 @@ Testing has been conducted on AWS Graviton3 instances for compatibility. See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image. Stable vLLM Docker images are being pre-built for Arm from version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo). -Please replace `` in the command below with a specific version string (e.g., `0.12.0`). ```bash -docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v +export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//') +docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v${VLLM_VERSION} ``` You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days. diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 94920dc5306b3..e3974354d8f3b 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -281,17 +281,27 @@ Alternatively, you can use the `openai` Python package: Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications. 
-If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: +If desired, you can also manually set the backend of your choice using the `--attention-backend` CLI argument: + +```bash +# For online serving +vllm serve Qwen/Qwen2.5-1.5B-Instruct --attention-backend FLASH_ATTN + +# For offline inference +python script.py --attention-backend FLASHINFER +``` + +Some of the available backend options include: - On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`. - On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`. -For AMD ROCm, you can further control the specific Attention implementation using the following variables: +For AMD ROCm, you can further control the specific Attention implementation using the following options: -- Triton Unified Attention: `VLLM_ROCM_USE_AITER=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0` -- AITER Unified Attention: `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0` -- Triton Prefill-Decode Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` -- AITER Multi-head Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=1` +- Triton Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=0 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument. +- AITER Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument. +- Triton Prefill-Decode Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=true` as a CLI argument. +- AITER Multi-head Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=1` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument. !!! warning There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see [docker/Dockerfile](../../docker/Dockerfile) for instructions on how to install it. diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 586d5d91634dc..9ba0f4ca9096e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -659,7 +659,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|--------|-------------------|----------------------|---------------------------| | `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | | | +| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A+ | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ | | `AyaVisionForConditionalGeneration` | Aya Vision | T + I+ | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. 
| | ✅︎ | +| `BagelForConditionalGeneration` | BAGEL | T + I+ | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ | | `BeeForConditionalGeneration` | Bee-8B | T + IE+ | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ | | `Blip2ForConditionalGeneration` | BLIP-2 | T + IE | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | @@ -743,7 +745,7 @@ Some models are supported only via the [Transformers modeling backend](#transfor - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. !!! note - For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc), InternVL3 and InternVL3.5 have video inputs support currently. + For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc.), InternVL3 and InternVL3.5 have video inputs support currently. !!! note To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. diff --git a/docs/serving/data_parallel_deployment.md b/docs/serving/data_parallel_deployment.md index e5954917cd790..f0946eaf407a9 100644 --- a/docs/serving/data_parallel_deployment.md +++ b/docs/serving/data_parallel_deployment.md @@ -8,11 +8,11 @@ For MoE models, particularly those like DeepSeek that employ MLA (Multi-head Lat In these cases, the data parallel ranks are not completely independent. Forward passes must be aligned, and expert layers across all ranks are required to synchronize during every forward pass, even when there are fewer requests to be processed than DP ranks. -The expert layers will by default form a (DP x TP) sized tensor parallel group. To enable expert parallelism, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case). +By default, expert layers form a tensor parallel group of size `DP × TP`. To use expert parallelism instead, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case). See [Expert Parallel Deployment](expert_parallel_deployment.md) for details on how attention and expert layers behave differently with EP enabled. In vLLM, each DP rank is deployed as a separate "core engine" process that communicates with front-end process(es) via ZMQ sockets. Data Parallel attention can be combined with Tensor Parallel attention, in which case each DP engine owns a number of per-GPU worker processes equal to the configured TP size. -For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form an EP or TP group of size (DP x TP). +For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. 
This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form a group of size `DP × TP` (using either tensor parallelism by default, or expert parallelism if `--enable-expert-parallel` is set). In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently. diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 923020dc88c91..82fde27d71fd4 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -44,7 +44,27 @@ Where: - `DP_SIZE`: Data parallel size - `EP_SIZE`: Expert parallel size (computed automatically) -When EP is enabled, MoE layers use expert parallelism instead of tensor parallelism, while attention layers continue to use tensor parallelism if `TP_SIZE > 1`. +### Layer Behavior with EP Enabled + +When EP is enabled, different layers in MoE models behave differently: + +| Layer Type | Behavior | Parallelism Used | +|------------|----------|------------------| +| **Expert (MoE) Layers** | Sharded across all EP ranks | Expert Parallel (EP) of size `TP × DP` | +| **Attention Layers** | Behavior depends on TP size | See below | + +**Attention layer parallelism:** + +- **When `TP = 1`**: Attention weights are **replicated** across all DP ranks (data parallelism) +- **When `TP > 1`**: Attention weights are **sharded** using tensor parallelism across TP ranks within each DP group + +For example, with `TP=2, DP=4` (8 GPUs total): + +- Expert layers form an EP group of size 8, with experts distributed across all GPUs +- Attention layers use TP=2 within each of the 4 DP groups + +!!! note "Key Difference from Data Parallel Deployment" + Without `--enable-expert-parallel`, MoE layers would use tensor parallelism (forming a TP group of size `TP × DP`), similar to dense models. With EP enabled, expert layers switch to expert parallelism, which can provide better efficiency and locality for MoE models. ### Example Command diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md index a32840ea73b9a..ed93432701f35 100644 --- a/docs/serving/parallelism_scaling.md +++ b/docs/serving/parallelism_scaling.md @@ -62,7 +62,7 @@ If a single node lacks sufficient GPUs to hold the model, deploy vLLM across mul ### What is Ray? -Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments require Ray as the runtime engine. +Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments can use Ray as the runtime engine. vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens. @@ -130,9 +130,31 @@ vllm serve /path/to/the/model/in/the/container \ --distributed-executor-backend ray ``` +### Running vLLM with MultiProcessing + +Besides Ray, Multi-node vLLM deployments can also use `multiprocessing` as the runtime engine. Here's an example to deploy model across 2 nodes (8 GPUs per node) with `tp_size=8` and `pp_size=2`. 
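Note that the total GPU count must equal `tensor_parallel_size × pipeline_parallel_size`: here 8 × 2 = 16, i.e. 2 nodes × 8 GPUs each. With this layout, each tensor-parallel group of size 8 fits inside a single node and the two pipeline stages are split across the nodes, which is generally preferable because tensor parallelism is the more communication-intensive of the two dimensions. Once both commands below are running, a quick sanity check is to query the OpenAI-compatible API on the head node (a minimal sketch, assuming the default port 8000; `<head-node-ip>` is a placeholder for the head node's reachable address):

```bash
# List the served models via the head node's OpenAI-compatible API
curl http://<head-node-ip>:8000/v1/models
```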
+
+Choose one node as the head node and run:
+
+```bash
+vllm serve /path/to/the/model/in/the/container \
+    --tensor-parallel-size 8 --pipeline-parallel-size 2 \
+    --nnodes 2 --node-rank 0 \
+    --master-addr <head-node-ip>
+```
+
+On the other node (the worker), run:
+
+```bash
+vllm serve /path/to/the/model/in/the/container \
+    --tensor-parallel-size 8 --pipeline-parallel-size 2 \
+    --nnodes 2 --node-rank 1 \
+    --master-addr <head-node-ip> --headless
+```
+
 ## Optimizing network communication for tensor parallelism
 
-Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand.
+Efficient tensor parallelism requires fast internode communication, preferably through high-speed network adapters such as InfiniBand.
 
 To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the [examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh) helper script. Contact your system administrator for more information about the required flags.
diff --git a/docs/usage/security.md b/docs/usage/security.md
index 74060d86f6854..e619eec660aee 100644
--- a/docs/usage/security.md
+++ b/docs/usage/security.md
@@ -10,7 +10,7 @@ All communications between nodes in a multi-node vLLM deployment are **insecure
 
 ### Configuration Options for Inter-Node Communications
 
-The following options control inter-node communications in vLLM:
+The following options control internode communications in vLLM:
 
 #### 1. **Environment Variables:**
 
@@ -28,7 +28,7 @@ The following options control inter-node communications in vLLM:
 
 ### Notes on PyTorch Distributed
 
-vLLM uses PyTorch's distributed features for some inter-node communication. For
+vLLM uses PyTorch's distributed features for some internode communication. For
 detailed information about PyTorch Distributed security considerations, please
 refer to the
 [PyTorch Security Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features).
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 40462c78ae8c2..a6d0c5d12dd41 100755
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -42,60 +42,31 @@ class ModelRequestData(NamedTuple):
 
 # Unless specified, these settings have been tested to work on a single L4.
 
-# Voxtral
-# Make sure to install mistral-common[audio].
-def run_voxtral(question: str, audio_count: int) -> ModelRequestData: - from mistral_common.audio import Audio - from mistral_common.protocol.instruct.chunk import ( - AudioChunk, - RawAudio, - TextChunk, - ) - from mistral_common.protocol.instruct.messages import ( - UserMessage, - ) - from mistral_common.protocol.instruct.request import ChatCompletionRequest - from mistral_common.tokens.tokenizers.mistral import MistralTokenizer - - model_name = "mistralai/Voxtral-Mini-3B-2507" - tokenizer = MistralTokenizer.from_hf_hub(model_name) - +# AudioFlamingo3 +def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData: + model_name = "nvidia/audio-flamingo-3-hf" engine_args = EngineArgs( model=model_name, - max_model_len=8192, + max_model_len=4096, max_num_seqs=2, limit_mm_per_prompt={"audio": audio_count}, - config_format="mistral", - load_format="mistral", - tokenizer_mode="mistral", enforce_eager=True, - enable_chunked_prefill=False, ) - text_chunk = TextChunk(text=question) - audios = [ - Audio.from_file(str(audio_assets[i].get_local_path()), strict=False) - for i in range(audio_count) - ] - audio_chunks = [ - AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios - ] + # AudioFlamingo3 uses token for audio + audio_placeholder = "" * audio_count - messages = [UserMessage(content=[*audio_chunks, text_chunk])] - - req = ChatCompletionRequest(messages=messages, model=model_name) - - tokens = tokenizer.encode_chat_completion(req) - prompt_ids, audios = tokens.tokens, tokens.audios - - audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios] - - multi_modal_data = {"audio": audios_and_sr} + prompt = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_placeholder}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) return ModelRequestData( engine_args=engine_args, - prompt_token_ids=prompt_ids, - multi_modal_data=multi_modal_data, + prompt=prompt, ) @@ -361,6 +332,63 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData: ) +# Voxtral +# Make sure to install mistral-common[audio]. 
+def run_voxtral(question: str, audio_count: int) -> ModelRequestData: + from mistral_common.audio import Audio + from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + RawAudio, + TextChunk, + ) + from mistral_common.protocol.instruct.messages import ( + UserMessage, + ) + from mistral_common.protocol.instruct.request import ChatCompletionRequest + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + + model_name = "mistralai/Voxtral-Mini-3B-2507" + tokenizer = MistralTokenizer.from_hf_hub(model_name) + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"audio": audio_count}, + config_format="mistral", + load_format="mistral", + tokenizer_mode="mistral", + enforce_eager=True, + enable_chunked_prefill=False, + ) + + text_chunk = TextChunk(text=question) + audios = [ + Audio.from_file(str(audio_assets[i].get_local_path()), strict=False) + for i in range(audio_count) + ] + audio_chunks = [ + AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios + ] + + messages = [UserMessage(content=[*audio_chunks, text_chunk])] + + req = ChatCompletionRequest(messages=messages, model=model_name) + + tokens = tokenizer.encode_chat_completion(req) + prompt_ids, audios = tokens.tokens, tokens.audios + + audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios] + + multi_modal_data = {"audio": audios_and_sr} + + return ModelRequestData( + engine_args=engine_args, + prompt_token_ids=prompt_ids, + multi_modal_data=multi_modal_data, + ) + + # Whisper def run_whisper(question: str, audio_count: int) -> ModelRequestData: assert audio_count == 1, "Whisper only support single audio input per prompt" @@ -382,7 +410,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: model_example_map = { - "voxtral": run_voxtral, + "audioflamingo3": run_audioflamingo3, "gemma3n": run_gemma3n, "granite_speech": run_granite_speech, "midashenglm": run_midashenglm, @@ -392,6 +420,7 @@ model_example_map = { "qwen2_audio": run_qwen2_audio, "qwen2_5_omni": run_qwen2_5_omni, "ultravox": run_ultravox, + "voxtral": run_voxtral, "whisper": run_whisper, } diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 9142279140e56..dd5b22ae9b0f6 100755 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -118,6 +118,32 @@ def run_bee(questions: list[str], modality: str) -> ModelRequestData: ) +def run_bagel(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "ByteDance-Seed/BAGEL-7B-MoT" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + ) + + prompts = [ + ( + f"<|im_start|>user\n<|image_pad|>\n{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # BLIP-2 def run_blip2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1832,6 +1858,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: model_example_map = { "aria": run_aria, "aya_vision": run_aya_vision, + "bagel": run_bagel, "bee": run_bee, "blip-2": run_blip2, "chameleon": run_chameleon, diff --git a/examples/online_serving/structured_outputs/structured_outputs.py 
b/examples/online_serving/structured_outputs/structured_outputs.py index ff473d044e323..2599c951ef8ad 100644 --- a/examples/online_serving/structured_outputs/structured_outputs.py +++ b/examples/online_serving/structured_outputs/structured_outputs.py @@ -112,7 +112,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { "messages": [ { "role": "user", - "content": "Generate an SQL query to show the 'username' and 'email'from the 'users' table.", + "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.", } ], "extra_body": { diff --git a/tests/benchmarks/test_param_sweep.py b/tests/benchmarks/test_param_sweep.py index 0d47cfd9d6230..467797d9915c9 100644 --- a/tests/benchmarks/test_param_sweep.py +++ b/tests/benchmarks/test_param_sweep.py @@ -23,14 +23,6 @@ class TestParameterSweepItem: {"compilation_config.use_inductor_graph_partition": True}, "--compilation-config.use_inductor_graph_partition=true", ), - ( - {"compilation_config.use_inductor": False}, - "--compilation-config.use_inductor=false", - ), - ( - {"compilation_config.use_inductor": True}, - "--compilation-config.use_inductor=true", - ), ], ) def test_nested_boolean_params(self, input_dict, expected): diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 5379b5157b811..bd326f1157d8f 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -20,13 +20,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer from ...utils import flat_product, multi_gpu_test -is_blackwell = lambda: current_platform.is_device_capability(100) +is_blackwell = lambda: current_platform.is_device_capability_family(100) """Are we running on Blackwell, a lot of tests depend on it""" class Matches(NamedTuple): attention_fusion: int = 0 allreduce_fusion: int = 0 + rms_quant_norm_fusion: int = 0 sequence_parallel: int = 0 async_tp: int = 0 @@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple): MODELS_FP8: list[ModelBackendTestCase] = [] MODELS_FP4: list[ModelBackendTestCase] = [] +MODELS_GROUP_FP8: list[ModelBackendTestCase] = [] MODELS: list[ModelBackendTestCase] = [] # tp-only if current_platform.is_cuda(): @@ -498,3 +500,79 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg compilation_config.compile_ranges_split_points = ( llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points ) + + +if current_platform.is_cuda(): + MODELS_GROUP_FP8 = [ + ModelBackendTestCase( + model_name="Qwen/Qwen3-30B-A3B-FP8", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=AttentionBackendEnum.TRITON_ATTN, + matches=Matches( + rms_quant_norm_fusion=48, + ), + ), + ] + +CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"] + + +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, matches, custom_ops", + # Test rms norm+group quant_fp8 fusion + list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +def test_rms_group_quant( + model_name: str, + model_kwargs: dict[str, Any], + backend: AttentionBackendEnum, + matches: Matches, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if 
inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common + mode=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig(eliminate_noops=True, enable_fusion=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model(compilation_config, model_name, **model_kwargs) + + log_matches = re.findall( + r"\[fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 1, log_holder.text + assert int(log_matches[0]) == matches.rms_quant_norm_fusion diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py index bc3dbf5533312..9ccb363b088f5 100644 --- a/tests/compile/test_dynamic_shapes_compilation.py +++ b/tests/compile/test_dynamic_shapes_compilation.py @@ -36,7 +36,7 @@ def get_test_models(): DynamicShapesType.BACKED_SIZE_OBLIVIOUS, ], ) -@pytest.mark.parametrize("use_aot_compile", ["0"]) +@pytest.mark.parametrize("use_aot_compile", ["0", "1"]) @pytest.mark.parametrize("use_bytecode_hook", [True, False]) @pytest.mark.parametrize("evaluate_guards", [False, True]) @pytest.mark.skipif( @@ -54,6 +54,12 @@ def test_dynamic_shapes_compilation( if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED: pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0") + if evaluate_guards and shapes_type == DynamicShapesType.UNBACKED: + pytest.skip("unbacked dynamic shapes do not add guards") + + if evaluate_guards and use_aot_compile: + pytest.skip("evaluate_guards requires use_aot_compile=0") + monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile) monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") @@ -120,7 +126,7 @@ def test_model_specialization_with_evaluate_guards( and dynamic_shapes_type == DynamicShapesType.BACKED and evaluate_guards ): - pytest.skip("evaluate_guards for backed does not work with aot_compile =1") + pytest.skip("evaluate_guards for backed does not work with aot_compile=1") @support_torch_compile class ModelWithSizeCheck(torch.nn.Module): diff --git a/tests/conftest.py b/tests/conftest.py index b21cfd5ba85c4..a03f40a9a72ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -702,10 +702,16 @@ class HfRunner: **kwargs, ) + # Encoder-decoder models return decoder_hidden_states instead of + # hidden_states + hidden_states = ( + getattr(output, "hidden_states", None) or output.decoder_hidden_states + ) + ( seq_logprobs_lst, output_len, - ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs) + ) = self._hidden_states_to_logprobs(hidden_states, num_logprobs) all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] diff --git 
a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py index 102eeaf614410..b194e9b74d874 100644 --- a/tests/entrypoints/openai/test_chat_error.py +++ b/tests/entrypoints/openai/test_chat_error.py @@ -80,10 +80,9 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: return dict(engine_prompt), {} async def _fake_preprocess_chat(*args, **kwargs): - # return conversation, request_prompts, engine_prompts + # return conversation, engine_prompts return ( [{"role": "user", "content": "Test"}], - [[1, 2, 3]], [{"prompt_token_ids": [1, 2, 3]}], ) diff --git a/tests/entrypoints/openai/test_response_api_parsable_context.py b/tests/entrypoints/openai/test_response_api_parsable_context.py index 1899c5f04fe3f..6d97602f32475 100644 --- a/tests/entrypoints/openai/test_response_api_parsable_context.py +++ b/tests/entrypoints/openai/test_response_api_parsable_context.py @@ -165,6 +165,7 @@ async def test_mcp_tool_call(client: OpenAI, model_name: str): model=model_name, input="What is 13 * 24? Use python to calculate the result.", tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], + extra_body={"enable_response_messages": True}, temperature=0.0, ) @@ -178,3 +179,8 @@ async def test_mcp_tool_call(client: OpenAI, model_name: str): # make sure the correct math is in the final output assert response.output[3].type == "message" assert "312" in response.output[3].content[0].text + + # test raw input_messages / output_messages + assert len(response.input_messages) == 1 + assert len(response.output_messages) == 3 + assert "312" in response.output_messages[2]["message"] diff --git a/tests/entrypoints/openai/test_response_api_simple.py b/tests/entrypoints/openai/test_response_api_simple.py index aee03199bc6f4..02e06297f3987 100644 --- a/tests/entrypoints/openai/test_response_api_simple.py +++ b/tests/entrypoints/openai/test_response_api_simple.py @@ -87,3 +87,48 @@ async def test_reasoning_item(client: OpenAI, model_name: str): assert response.output[0].type == "reasoning" assert response.output[1].type == "message" assert type(response.output[1].content[0].text) is str + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_streaming_output_consistency(client: OpenAI, model_name: str): + """Test that streaming delta text matches the final response output_text. + + This test verifies that when using streaming mode: + 1. The concatenated text from all 'response.output_text.delta' events + 2. 
Matches the 'output_text' in the final 'response.completed' event + """ + response = await client.responses.create( + model=model_name, + input="Say hello in one sentence.", + stream=True, + ) + + events = [] + async for event in response: + events.append(event) + + assert len(events) > 0 + + # Concatenate all delta text from streaming events + streaming_text = "".join( + event.delta for event in events if event.type == "response.output_text.delta" + ) + + # Get the final response from the last event + response_completed_event = events[-1] + assert response_completed_event.type == "response.completed" + assert response_completed_event.response.status == "completed" + + # Get output_text from the final response + final_output_text = response_completed_event.response.output_text + + # Verify final response has output + assert len(response_completed_event.response.output) > 0 + + # Verify streaming text matches final output_text + assert streaming_text == final_output_text, ( + f"Streaming text does not match final output_text.\n" + f"Streaming: {streaming_text!r}\n" + f"Final: {final_output_text!r}" + ) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 5a9293f1b9ae5..444275e061c61 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -19,9 +19,9 @@ from vllm.entrypoints.openai.protocol import ( ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.outputs import CompletionOutput, RequestOutput from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers import ToolParserManager from vllm.v1.engine.async_llm import AsyncLLM from ...utils import RemoteOpenAIServer @@ -877,7 +877,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the first turn's input req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages, _, _ = serving_chat._make_request_with_harmony(req) + input_messages, _ = serving_chat._make_request_with_harmony(req) verify_harmony_messages( input_messages, [ @@ -905,7 +905,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the second turn's input req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) verify_harmony_messages( input_messages_2, [ @@ -927,7 +927,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the first turn's input req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools) - input_messages, _, _ = serving_chat._make_request_with_harmony(req) + input_messages, _ = serving_chat._make_request_with_harmony(req) verify_harmony_messages( input_messages, [ @@ -971,7 +971,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the second turn's input req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) verify_harmony_messages( input_messages_2, [ @@ -1008,7 +1008,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the first turn's input req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools) - 
input_messages, _, _ = serving_chat._make_request_with_harmony(req) + input_messages, _ = serving_chat._make_request_with_harmony(req) verify_harmony_messages( input_messages, [ @@ -1052,7 +1052,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the second turn's input req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) verify_harmony_messages( input_messages_2, [ @@ -1089,7 +1089,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the first turn's input req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools) - input_messages, _, _ = serving_chat._make_request_with_harmony(req) + input_messages, _ = serving_chat._make_request_with_harmony(req) verify_harmony_messages( input_messages, [ @@ -1133,7 +1133,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the second turn's input req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) verify_harmony_messages( input_messages_2, [ @@ -1183,7 +1183,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the third turn's input req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages_3, _, _ = serving_chat._make_request_with_harmony(req_3) + input_messages_3, _ = serving_chat._make_request_with_harmony(req_3) verify_harmony_messages( input_messages_3, [ @@ -1246,7 +1246,7 @@ class TestServingChatWithHarmony: # Test the Harmony messages for the fourth turn's input req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages_4, _, _ = serving_chat._make_request_with_harmony(req_4) + input_messages_4, _ = serving_chat._make_request_with_harmony(req_4) verify_harmony_messages( input_messages_4, [ @@ -1295,7 +1295,7 @@ class TestServingChatWithHarmony: }, ] req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages, _, _ = serving_chat._make_request_with_harmony(req) + input_messages, _ = serving_chat._make_request_with_harmony(req) verify_harmony_messages( input_messages, @@ -1327,7 +1327,7 @@ class TestServingChatWithHarmony: }, ] req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages, _, _ = serving_chat._make_request_with_harmony(req) + input_messages, _ = serving_chat._make_request_with_harmony(req) verify_harmony_messages( input_messages, @@ -1357,7 +1357,7 @@ class TestServingChatWithHarmony: }, ] req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) - input_messages, _, _ = serving_chat._make_request_with_harmony(req) + input_messages, _ = serving_chat._make_request_with_harmony(req) verify_harmony_messages( input_messages, diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 956a06dc5487c..192c7cafb7493 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -10,7 +10,7 @@ import pytest from vllm.config import ModelConfig from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer @pytest.fixture() diff --git 
a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index cf00f0a042241..7d03dccec30de 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -21,7 +21,7 @@ from vllm.entrypoints.openai.serving_responses import ( extract_tool_types, ) from vllm.entrypoints.tool_server import ToolServer -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt class MockConversationContext(ConversationContext): @@ -237,7 +237,7 @@ class TestValidateGeneratorInput: """Test _validate_generator_input with valid prompt length""" # Create an engine prompt with valid length (less than max_model_len) valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len - engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids) # Call the method result = serving_responses_instance._validate_generator_input(engine_prompt) @@ -247,7 +247,7 @@ class TestValidateGeneratorInput: # create an invalid engine prompt invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len - engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids) # Call the method result = serving_responses_instance._validate_generator_input(engine_prompt) diff --git a/tests/entrypoints/openai/test_sparse_tensor_validation.py b/tests/entrypoints/openai/test_sparse_tensor_validation.py new file mode 100644 index 0000000000000..907c82b57dead --- /dev/null +++ b/tests/entrypoints/openai/test_sparse_tensor_validation.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Sparse tensor validation in embedding APIs. + +Tests verify that malicious sparse tensors are rejected before they can trigger +out-of-bounds memory writes during to_dense() operations. +""" + +import base64 +import io + +import pytest +import torch + +from vllm.entrypoints.renderer import CompletionRenderer +from vllm.multimodal.audio import AudioEmbeddingMediaIO +from vllm.multimodal.image import ImageEmbeddingMediaIO + + +def _encode_tensor(tensor: torch.Tensor) -> bytes: + """Helper to encode a tensor as base64 bytes.""" + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + return base64.b64encode(buffer.read()) + + +def _create_malicious_sparse_tensor() -> torch.Tensor: + """ + Create a malicious sparse COO tensor with out-of-bounds indices. + + This tensor has indices that point beyond the declared shape, which would + cause an out-of-bounds write when converted to dense format without + validation. 
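+    Densifying such a tensor (for example via to_dense()) scatters the stored
+    values at positions computed directly from the raw indices, so an index
+    outside the declared shape writes past the end of the allocated dense buffer.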
+ """ + # Create a 3x3 sparse tensor but with indices pointing to (10, 10) + indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape + values = torch.tensor([1.0]) + shape = (3, 3) + + # Create sparse tensor (this will be invalid) + sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32) + return sparse_tensor + + +def _create_valid_sparse_tensor() -> torch.Tensor: + """Create a valid sparse COO tensor for baseline testing.""" + indices = torch.tensor([[0, 1, 2], [0, 1, 2]]) + values = torch.tensor([1.0, 2.0, 3.0]) + shape = (3, 3) + + sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32) + return sparse_tensor + + +def _create_valid_dense_tensor() -> torch.Tensor: + """Create a valid dense tensor for baseline testing.""" + return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size) + + +class TestPromptEmbedsValidation: + """Test sparse tensor validation in prompt embeddings (Completions API).""" + + def test_valid_dense_tensor_accepted(self, model_config): + """Baseline: Valid dense tensors should work normally.""" + renderer = CompletionRenderer(model_config) + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = renderer.load_prompt_embeds(encoded) + assert len(result) == 1 + assert result[0]["prompt_embeds"].shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should load successfully.""" + io_handler = ImageEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception (sparse tensors remain sparse) + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_sparse.shape + + def test_malicious_sparse_tensor_rejected(self, model_config): + """Security: Malicious sparse tensors should be rejected.""" + renderer = CompletionRenderer(model_config) + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + renderer.load_prompt_embeds(encoded) + + # Error should indicate sparse tensor validation failure + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_extremely_large_indices_rejected(self, model_config): + """Security: Sparse tensors with extremely large indices should be rejected.""" + renderer = CompletionRenderer(model_config) + + # Create tensor with indices far beyond reasonable bounds + indices = torch.tensor([[999999], [999999]]) + values = torch.tensor([1.0]) + shape = (10, 10) + + malicious_tensor = torch.sparse_coo_tensor( + indices, values, shape, dtype=torch.float32 + ) + encoded = _encode_tensor(malicious_tensor) + + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(encoded) + + def test_negative_indices_rejected(self, model_config): + """Security: Sparse tensors with negative indices should be rejected.""" + renderer = CompletionRenderer(model_config) + + # Create tensor with negative indices + indices = torch.tensor([[-1], [-1]]) + values = torch.tensor([1.0]) + shape = (10, 10) + + malicious_tensor = torch.sparse_coo_tensor( + indices, values, shape, dtype=torch.float32 + ) + encoded = _encode_tensor(malicious_tensor) + + with pytest.raises((RuntimeError, 
ValueError)): + renderer.load_prompt_embeds(encoded) + + +class TestImageEmbedsValidation: + """Test sparse tensor validation in image embeddings (Chat API).""" + + def test_valid_dense_tensor_accepted(self): + """Baseline: Valid dense tensors should work normally.""" + io_handler = ImageEmbeddingMediaIO() + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should load successfully.""" + io_handler = AudioEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception (sparse tensors remain sparse) + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_sparse.shape + + def test_malicious_sparse_tensor_rejected(self): + """Security: Malicious sparse tensors should be rejected.""" + io_handler = ImageEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + io_handler.load_base64("", encoded.decode("utf-8")) + + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_load_bytes_validates(self): + """Security: Validation should also work for load_bytes method.""" + io_handler = ImageEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + buffer = io.BytesIO() + torch.save(malicious_tensor, buffer) + buffer.seek(0) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_bytes(buffer.read()) + + +class TestAudioEmbedsValidation: + """Test sparse tensor validation in audio embeddings (Chat API).""" + + def test_valid_dense_tensor_accepted(self): + """Baseline: Valid dense tensors should work normally.""" + io_handler = AudioEmbeddingMediaIO() + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should be converted successfully.""" + io_handler = AudioEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.is_sparse is False + + def test_malicious_sparse_tensor_rejected(self): + """Security: Malicious sparse tensors should be rejected.""" + io_handler = AudioEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + io_handler.load_base64("", encoded.decode("utf-8")) + + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_load_bytes_validates(self): + """Security: Validation should also work for load_bytes method.""" + io_handler = AudioEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + buffer = 
io.BytesIO() + torch.save(malicious_tensor, buffer) + buffer.seek(0) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_bytes(buffer.read()) + + +class TestSparseTensorValidationIntegration: + """ + These tests verify the complete attack chain is blocked at all entry points. + """ + + def test_attack_scenario_completions_api(self, model_config): + """ + Simulate a complete attack through the Completions API. + + Attack scenario: + 1. Attacker crafts malicious sparse tensor + 2. Encodes it as base64 + 3. Sends to /v1/completions with prompt_embeds parameter + 4. Server should reject before memory corruption occurs + """ + renderer = CompletionRenderer(model_config) + + # Step 1-2: Attacker creates malicious payload + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + # Step 3-4: Server processes and should reject + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(attack_payload) + + def test_attack_scenario_chat_api_image(self): + """ + Simulate attack through Chat API with image_embeds. + + Verifies the image embeddings path is protected. + """ + io_handler = ImageEmbeddingMediaIO() + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_base64("", attack_payload.decode("utf-8")) + + def test_attack_scenario_chat_api_audio(self): + """ + Simulate attack through Chat API with audio_embeds. + + Verifies the audio embeddings path is protected. + """ + io_handler = AudioEmbeddingMediaIO() + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_base64("", attack_payload.decode("utf-8")) + + def test_multiple_valid_embeddings_in_batch(self, model_config): + """ + Regression test: Multiple valid embeddings should still work. + + Ensures the fix doesn't break legitimate batch processing. + """ + renderer = CompletionRenderer(model_config) + + valid_tensors = [ + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_valid_dense_tensor()), + ] + + # Should process all without error + result = renderer.load_prompt_embeds(valid_tensors) + assert len(result) == 3 + + def test_mixed_valid_and_malicious_rejected(self, model_config): + """ + Security: Batch with one malicious tensor should be rejected. + + Even if most tensors are valid, a single malicious one should + cause rejection of the entire batch. 
+ """ + renderer = CompletionRenderer(model_config) + + mixed_batch = [ + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_malicious_sparse_tensor()), # Malicious + _encode_tensor(_create_valid_dense_tensor()), + ] + + # Should fail on the malicious tensor + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(mixed_batch) + + +# Pytest fixtures +@pytest.fixture +def model_config(): + """Mock ModelConfig for testing.""" + from vllm.config import ModelConfig + + return ModelConfig( + model="facebook/opt-125m", + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float32", + seed=0, + enable_prompt_embeds=True, # Required for prompt embeds tests + ) diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py index 3c507ee0a3fa7..8bf729c517f7a 100644 --- a/tests/entrypoints/openai/test_transcription_validation_whisper.py +++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py @@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client): ) assert transcription.segments is not None assert len(transcription.segments) > 0 + + +@pytest.mark.asyncio +async def test_audio_with_max_tokens(whisper_client, mary_had_lamb): + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": 1}, + ) + out = json.loads(transcription) + out_text = out["text"] + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(MODEL_NAME) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) == 1 + # max_completion_tokens > max_model_len + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": int(1e6)}, + ) + out = json.loads(transcription) + out_text = out["text"] + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) < 450 # ~Whisper max output len diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index d7d407484f16d..2c577237691ab 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -227,3 +227,36 @@ async def test_long_audio_request(foscolo, client_and_model): ) out = json.loads(translation)["text"].strip().lower() assert out.count("greek sea") == 2 + + +@pytest.mark.asyncio +async def test_audio_with_max_tokens(mary_had_lamb, client_and_model): + client, model_name = client_and_model + transcription = await client.audio.translations.create( + model=model_name, + file=mary_had_lamb, + response_format="text", + temperature=0.0, + extra_body={"max_completion_tokens": 1}, + ) + out = json.loads(transcription) + out_text = out["text"] + print(out_text) + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_name) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) == 1 + # max_completion_tokens > max_model_len + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + response_format="text", + temperature=0.0, + 
extra_body={"max_completion_tokens": int(1e6)}, + ) + out = json.loads(transcription) + out_text = out["text"] + print(out_text) + out_tokens = tok(out_text, add_special_tokens=False)["input_ids"] + assert len(out_tokens) < 450 # ~Whisper max output len diff --git a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py index 02c5189d0f6c1..6ac48317e8bc6 100644 --- a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py @@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager SIMPLE_ARGS_DICT = { "action": "create", diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index ce6727bb04f6c..8600aaf639431 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -6,8 +6,8 @@ import json import pytest from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from ....utils import RemoteOpenAIServer diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py index bdd5344652c4b..3944575321391 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -12,7 +12,7 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.tool_parsers import ToolParser, ToolParserManager def make_tool_call(name, arguments): diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py index 6c286ca90ce48..3ce7801b45975 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py @@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch import pytest from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation -from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser @pytest.fixture diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 8aa88a007188f..3bd1ca7f528d0 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import ( 
run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager # Test cases similar to pythonic parser but with Llama4 specific format SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" diff --git a/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py index a0b9a3c563bc2..3774b3d1833e9 100644 --- a/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py @@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index 52202c55e8405..c4cad17fd2d01 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index 2d4f5f1734102..0b32e5f899ff4 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -10,8 +10,8 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser class StreamingToolReconstructor: diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 527322c71ae4b..a87a4c35d3dc7 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -29,7 +29,8 @@ from vllm.multimodal.utils import ( encode_image_base64, encode_video_base64, ) -from vllm.tokenizers import MistralTokenizer, get_tokenizer +from vllm.tokenizers import get_tokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.serial_utils import tensor2base64 from ..models.registry import HF_EXAMPLE_MODELS @@ -796,9 +797,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( "content": "<|image_1|>\nWhat's in this image?", } ] + assert mm_data is not None assert "image" in mm_data - assert mm_data["image"] is None + assert isinstance(mm_data["image"], list) + assert 
len(mm_data["image"]) == 1 + assert mm_data["image"][0] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) @@ -825,10 +830,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid( # Should have audio in mm_data as None (UUID provided) assert mm_data is not None assert "audio" in mm_data - assert mm_data["audio"] is None + assert isinstance(mm_data["audio"], list) + assert len(mm_data["audio"]) == 1 + assert mm_data["audio"][0] is None + # UUID should be recorded - assert mm_uuids is not None - assert "audio" in mm_uuids _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid]) @@ -1121,10 +1127,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( mm_data = await mm_future assert mm_data is not None assert "image" in mm_data - assert mm_data["image"] is None + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 1 + assert mm_data["image"][0] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) +def test_parse_chat_messages_empty_dict_image_embeds( + phi3v_model_config_image_embeds, +): + """Test that empty dictionary for image_embeds is handled without errors.""" + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_embeds", "image_embeds": {}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] + + # Verify mm_data contains an empty dictionary of embeddings + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], dict) + assert len(mm_data["image"]) == 0 + + # Verify UUIDs (None since we didn't provide any) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_dict_image_embeds( + phi3v_model_config_image_embeds, +): + """Test that multiple dictionaries for image_embeds is handled without errors.""" + # Create two sample image embedding tensors + batch_size = 2 + image_embedding_1 = torch.randn(batch_size, 256, 1024) + image_embedding_2 = torch.randn(batch_size, 3) + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": { + "image_embedding_1": tensor2base64(p), + "image_embedding_2": tensor2base64(i), + }, + } + for p, i in zip(image_embedding_1, image_embedding_2) + ] + + [ + {"type": "text", "text": "Describe these two images."}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.", + } + ] + + # Verify mm_data contains a dictionary of multi-embeddings + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], dict) + assert len(mm_data["image"]) == batch_size + + # Verify each embedding has the correct shape + assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor) + assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape + assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor) + assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape + + # Verify UUIDs (None since we didn't provide any) + _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None]) 
+ + @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_async( phi3v_model_config, diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md index 29c5199e1e87a..dcbfd85bfeee8 100644 --- a/tests/evals/gsm8k/README.md +++ b/tests/evals/gsm8k/README.md @@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation, ### Run tests with pytest (like buildkite) ```bash -pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \ - --config-list-file=configs/models-small.txt \ - --tp-size=1 +pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt ``` ### Run standalone evaluation script @@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct" accuracy_threshold: 0.54 # Minimum expected accuracy num_questions: 1319 # Number of questions (default: full test set) num_fewshot: 5 # Few-shot examples from train set -max_model_len: 4096 # Model context length +server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments +env: # Environment variables (optional) + VLLM_USE_FLASHINFER_MOE_FP4: "1" ``` + +The `server_args` field accepts any arguments that can be passed to `vllm serve`. + +The `env` field accepts a dictionary of environment variables to set for the server process. diff --git a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml index 7ec6a1e0be27f..72fa7e8a38c73 100644 --- a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml +++ b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml @@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8" accuracy_threshold: 0.72 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 - +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml index caa0448f23d48..b7b59e9dcd5ce 100644 --- a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml +++ b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml @@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" accuracy_threshold: 0.74 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml index 615aa69a2d2b6..8b3c9ff645e87 100644 --- a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml +++ b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml @@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8" accuracy_threshold: 0.31 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml index 9297bf6ddf2d3..4a1b1948acac8 100644 --- a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" accuracy_threshold: 0.45 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml 
b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml index 5319ada30f645..5ce3af8be346a 100644 --- a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +++ b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" accuracy_threshold: 0.60 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml index c39fb979d98ac..5452ebe753f04 100644 --- a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml +++ b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml @@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8" accuracy_threshold: 0.375 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 \ No newline at end of file +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml b/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml index 6b7bdd1e65bb3..f162aa8bfe5b0 100644 --- a/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml +++ b/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml @@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4" accuracy_threshold: 0.89 num_questions: 1319 num_fewshot: 5 -max_model_len: 4096 - +server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml b/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml new file mode 100644 index 0000000000000..673b473f817eb --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml @@ -0,0 +1,12 @@ +model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4" +accuracy_threshold: 0.75 +num_questions: 1319 +num_fewshot: 5 +server_args: >- + --enforce-eager + --max-model-len 4096 + --tensor-parallel-size 2 + --enable-expert-parallel + --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' +env: + VLLM_USE_FLASHINFER_MOE_FP4: "1" diff --git a/tests/evals/gsm8k/configs/models-blackwell.txt b/tests/evals/gsm8k/configs/models-blackwell.txt index 3c9b1084de7bc..39978aa6ffbe9 100644 --- a/tests/evals/gsm8k/configs/models-blackwell.txt +++ b/tests/evals/gsm8k/configs/models-blackwell.txt @@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-CT.yaml DeepSeek-V2-Lite-Instruct-FP8.yaml Qwen3-30B-A3B-NVFP4.yaml +Qwen3-Next-80B-A3B-NVFP4-EP2.yaml diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py index 1932a13cdfc63..6f25fe6414af4 100644 --- a/tests/evals/gsm8k/conftest.py +++ b/tests/evals/gsm8k/conftest.py @@ -11,14 +11,12 @@ def pytest_addoption(parser): default="configs/models-small.txt", help="File containing list of config files to test", ) - parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size") def pytest_generate_tests(metafunc): """Generate test parameters from config files.""" if "config_filename" in metafunc.fixturenames: config_list_file = metafunc.config.getoption("--config-list-file") - tp_size = metafunc.config.getoption("--tp-size") # Handle both relative and absolute paths config_list_path = Path(config_list_file) @@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc): # Generate test parameters if config_files: metafunc.parametrize( - ["config_filename", "tp_size"], - [(config_file, int(tp_size)) for config_file in config_files], - ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files], + 
"config_filename", + config_files, + ids=[config_file.stem for config_file in config_files], ) else: print("No config files found, test will be skipped") diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py index b5d67df7bf3db..ea6715f5cb532 100644 --- a/tests/evals/gsm8k/test_gsm8k_correctness.py +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script. Replacement for lm-eval-harness with better performance and control. Usage: -pytest -s -v test_gsm8k_correctness.py \ - --config-list-file=configs/models-small.txt \ - --tp-size=1 +pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \ + --config-list-file=configs/models-small.txt """ +import shlex + import yaml from tests.utils import RemoteOpenAIServer from .gsm8k_eval import evaluate_gsm8k -RTOL = 0.08 # Relative tolerance for accuracy comparison +TOL = 0.08 # Absolute tolerance for accuracy comparison -def launch_gsm8k_eval(eval_config, server_url, tp_size): - """Launch GSM8K evaluation using our isolated script.""" +def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict: + """Run GSM8K evaluation using our isolated script.""" # Extract host and port from server URL if "://" in server_url: server_url = server_url.split("://")[1] host_port = server_url.split("/")[0] # Remove path if present if ":" in host_port: - host, port = host_port.split(":") - port = int(port) + host, p = host_port.split(":") + port = int(p) else: host = host_port port = 8000 @@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size): return results -def test_gsm8k_correctness_param(config_filename, tp_size): +def test_gsm8k_correctness(config_filename): """Test GSM8K correctness for a given model configuration.""" eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) - # Server arguments - server_args = [ - "--max-model-len", - str(eval_config.get("max_model_len", 4096)), - "--enforce-eager", - "--trust-remote-code", - "--tensor-parallel-size", - str(tp_size), - ] + # Parse server arguments from config (use shlex to handle quoted strings) + server_args_str = eval_config.get("server_args", "") + server_args = shlex.split(server_args_str) if server_args_str else [] + + # Add standard server arguments + server_args.extend( + [ + "--trust-remote-code", + ] + ) env_dict = eval_config.get("env", None) + print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}") + print(f"Expected metric threshold: {eval_config['accuracy_threshold']}") + print(f"Number of questions: {eval_config['num_questions']}") + print(f"Number of few-shot examples: {eval_config['num_fewshot']}") + print(f"Server args: {' '.join(server_args)}") + # Launch server and run evaluation with RemoteOpenAIServer( - eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480 + eval_config["model_name"], + server_args, + env_dict=env_dict, + max_wait_seconds=600, ) as remote_server: server_url = remote_server.url_for("v1") + print(f"Server started at: {server_url}") - results = launch_gsm8k_eval(eval_config, server_url, tp_size) + results = run_gsm8k_eval(eval_config, server_url) - # Check accuracy against threshold - measured_accuracy = results["accuracy"] - expected_accuracy = eval_config["accuracy_threshold"] + measured_metric = results["accuracy"] + expected_metric = eval_config["accuracy_threshold"] print(f"GSM8K Results for {eval_config['model_name']}:") - print(f" Accuracy: {measured_accuracy:.3f}") 
- print(f" Expected: {expected_accuracy:.3f}") + print(f" Measured metric: {measured_metric:.4f}") + print(f" Expected metric: {expected_metric:.4f}") + print(f" Tolerance: {TOL:.4f}") print(f" Questions: {results['num_questions']}") print(f" Invalid rate: {results['invalid_rate']:.3f}") print(f" Latency: {results['latency']:.1f}s") print(f" QPS: {results['questions_per_second']:.1f}") - # Verify accuracy is within tolerance - assert measured_accuracy >= expected_accuracy - RTOL, ( - f"Accuracy too low: {measured_accuracy:.3f} < " - f"{expected_accuracy:.3f} - {RTOL:.3f}" + # Verify metric is within tolerance + assert measured_metric >= expected_metric - TOL, ( + f"GSM8K metric too low: {measured_metric:.4f} < " + f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}" ) print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py index a60f4e385a893..784c16304a286 100644 --- a/tests/kernels/attention/test_cutlass_mla_decode.py +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -32,8 +32,8 @@ def cal_diff( CUTLASS_MLA_UNSUPPORTED_REASON = ( - "Cutlass MLA Requires compute capability of 10 or above." - if not current_platform.is_device_capability(100) + "Cutlass MLA Requires compute capability of 100 or above." + if not current_platform.is_device_capability_family(100) else "Cutlass MLA is supported" ) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 98ea40608b468..06a7085a82ba0 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -11,7 +11,7 @@ from tests.kernels.quantization.nvfp4_utils import ( from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -if not current_platform.is_device_capability(100): +if not current_platform.is_device_capability_family(100): pytest.skip( "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True ) @@ -443,7 +443,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: - rtol, atol = 1e-1, 2e-1 + rtol, atol = 3e-1, 4e-1 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 4e-2, 6e-2 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: diff --git a/tests/kernels/core/test_apply_rotary_emb.py b/tests/kernels/core/test_apply_rotary_emb.py new file mode 100644 index 0000000000000..23c722fa5e638 --- /dev/null +++ b/tests/kernels/core/test_apply_rotary_emb.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for ApplyRotaryEmb CustomOp dispatch behavior. + +This test ensures that RotaryEmbedding classes correctly call the appropriate +ApplyRotaryEmb methods based on the calling context: + +1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native() +2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch) +3. 
RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch) +""" + +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ( + CompilationConfig, + VllmConfig, + get_cached_compilation_config, + set_current_vllm_config, +) +from vllm.platforms import current_platform + +CUDA_DEVICES = ["cuda:0"] + + +@dataclass +class RotaryEmbeddingTestCase: + """Test case configuration for RotaryEmbedding dispatch tests.""" + + name: str + rope_class: type + rope_kwargs: dict + method_name: str # forward_native, forward_cuda, forward + positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens) + expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native() + expect_forward: bool # Should call ApplyRotaryEmb.forward() + + +def get_test_cases() -> list[RotaryEmbeddingTestCase]: + """Generate test cases for all RotaryEmbedding classes.""" + from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding, + ) + from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding + from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding + + common_kwargs = { + "head_size": 128, + "rotary_dim": 128, + "max_position_embeddings": 4096, + "base": 10000, + "is_neox_style": True, + "dtype": torch.bfloat16, + } + + return [ + # MRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_native", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_cuda_1d", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_cuda", + positions_shape=(32,), # 1D triggers apply_rotary_emb path + expect_forward_native=False, + expect_forward=True, + ), + # XDRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="XDRotaryEmbedding.forward", + rope_class=XDRotaryEmbedding, + rope_kwargs={ + **common_kwargs, + "scaling_alpha": 1.0, + "xdrope_section": [16, 16, 16, 16], + }, + method_name="forward", + positions_shape=(4, 32), # 4D for P/W/H/T + expect_forward_native=False, + expect_forward=True, + ), + # Ernie4_5_VLRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="Ernie4_5_VLRotaryEmbedding.forward_native", + rope_class=Ernie4_5_VLRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + ] + + +def run_dispatch_test( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """Run a dispatch test for a RotaryEmbedding class.""" + vllm_config = VllmConfig( + compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"]) + ) + get_cached_compilation_config.cache_clear() + + with set_current_vllm_config(vllm_config): + rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device) + + apply_rotary_emb = rope.apply_rotary_emb + + # Verify custom op is enabled + if test_case.expect_forward_native: + assert ( + apply_rotary_emb._forward_method != apply_rotary_emb.forward_native + ), "Test setup error: ApplyRotaryEmb custom op should be enabled" + + # Setup call tracking + call_tracker = {"forward_native_called": False, "forward_called": False} + original_forward_native = 
apply_rotary_emb.forward_native + original_forward = apply_rotary_emb.forward + + def tracked_forward_native(*args, **kwargs): + call_tracker["forward_native_called"] = True + return original_forward_native(*args, **kwargs) + + def tracked_forward(*args, **kwargs): + call_tracker["forward_called"] = True + return original_forward(*args, **kwargs) + + apply_rotary_emb.forward_native = tracked_forward_native + apply_rotary_emb.forward = tracked_forward + + try: + num_tokens = test_case.positions_shape[-1] + num_q_heads = 8 + num_kv_heads = 2 + head_size = test_case.rope_kwargs["head_size"] + max_position = test_case.rope_kwargs["max_position_embeddings"] + + positions = torch.randint( + 0, max_position // 4, test_case.positions_shape, device=device + ) + query = torch.randn( + num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device + ) + key = torch.randn( + num_tokens, + num_kv_heads * head_size, + dtype=torch.bfloat16, + device=device, + ) + + # Call the method under test + method = getattr(rope, test_case.method_name) + method(positions, query.clone(), key.clone()) + + # Verify expectations + if test_case.expect_forward_native: + assert call_tracker["forward_native_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward_native()" + ) + if not test_case.expect_forward: + assert not call_tracker["forward_called"], ( + f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). " + "Bug: when +apply_rotary_emb is enabled, forward_native() " + "incorrectly dispatches to CUDA/HIP kernels." + ) + if test_case.expect_forward: + assert call_tracker["forward_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward()" + ) + finally: + apply_rotary_emb.forward_native = original_forward_native + apply_rotary_emb.forward = original_forward + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." +) +@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_rotary_embedding_dispatch( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """ + Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method. 
+ + - forward_native methods should call ApplyRotaryEmb.forward_native() + - forward_cuda/forward methods should call ApplyRotaryEmb.forward() + """ + run_dispatch_test(test_case, device) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index d95c22fdf0a5b..6078ce44cee9f 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -594,7 +594,8 @@ def make_modular_kernel( ) modular_kernel = mk.FusedMoEModularKernel( - prepare_finalize=prepare_finalize, fused_experts=fused_experts + prepare_finalize=prepare_finalize, + fused_experts=fused_experts, ) return modular_kernel diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index d553e2820e5ff..bf4ef2d30466b 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import pytest import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -107,6 +108,19 @@ class TestData: layer.w2_input_scale = a2_scale layer.w13_weight_scale = w13_weight_scale layer.w2_weight_scale = w2_weight_scale + # Setup dummy config. + layer.moe_parallel_config = mk.FusedMoEParallelConfig( + tp_size=1, + pcp_size=1, + dp_size=1, + ep_size=1, + tp_rank=1, + pcp_rank=1, + dp_rank=1, + ep_rank=1, + use_ep=False, + all2all_backend="naive", + ) register_moe_scaling_factors(layer) diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py index 5a850dda4f6fd..8fe471d124f43 100644 --- a/tests/kernels/moe/test_ocp_mx_moe.py +++ b/tests/kernels/moe/test_ocp_mx_moe.py @@ -17,7 +17,7 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( ) >= version.parse("0.8.99") TRTLLM_GEN_MXFP4_AVAILABLE = ( - current_platform.is_cuda() and current_platform.is_device_capability(100) + current_platform.is_cuda() and current_platform.is_device_capability_family(100) ) HOPPER_MXFP4_BF16_AVAILABLE = ( @@ -799,7 +799,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( @pytest.mark.skipif( not ( current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and has_flashinfer() ), reason="NVIDIA GPU sm100 and flashinfer are required for this test", diff --git a/tests/kernels/quantization/test_awq.py b/tests/kernels/quantization/test_awq.py index efb62ca3799a9..3bf59dea30972 100644 --- a/tests/kernels/quantization/test_awq.py +++ b/tests/kernels/quantization/test_awq.py @@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch): qweight = torch.randint( -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 ) - scales = torch.randint( + scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16) + qzeros = torch.randint( -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32 ) - qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16) split_k_iters = 8 - opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters)) + opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters)) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 6628ac650fd5f..f5e1cde94b6e9 100644 --- 
a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -62,7 +62,7 @@ def test_quantfp8_group_functionality( assert scales_col.stride(1) == batch_size # Test column-major scales consistency - assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) + torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8) # 3. Test CUDA implementation (only for divisible dimensions) if is_divisible: @@ -71,7 +71,7 @@ def test_quantfp8_group_functionality( assert scales_cuda.shape == (batch_size, expected_num_groups) # Verify CUDA/native consistency - assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) + torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8) # Quantized values should mostly match diff_count = (x_quant_cuda != x_quant_native).sum().item() diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py index f4269750feb6b..2fa61f280587f 100644 --- a/tests/lora/test_gptoss_tp.py +++ b/tests/lora/test_gptoss_tp.py @@ -76,6 +76,8 @@ def test_gpt_oss_lora(gptoss20b_lora_files): enable_lora=True, max_loras=4, max_lora_rank=8, + max_num_seqs=2, + max_num_batched_tokens=2048, compilation_config=vllm.config.CompilationConfig( # Avoid OOM cudagraph_specialize_lora=False, ), @@ -94,8 +96,10 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras): enable_lora=True, max_loras=2, max_lora_rank=8, - max_num_seqs=16, + max_num_seqs=2, + max_num_batched_tokens=2048, tensor_parallel_size=2, + gpu_memory_utilization=0.8, fully_sharded_loras=fully_sharded_loras, compilation_config=vllm.config.CompilationConfig( # Avoid OOM cudagraph_specialize_lora=False, diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 18704fa6e45de..483235ff51291 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -76,11 +76,18 @@ def do_sample( if lora_id else None, ) - # Print the outputs. 
+ lora_request = LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text + # The output should include correct lora_request info + if lora_request is not None: + assert output.lora_request.lora_name == lora_request.lora_name + assert output.lora_request.lora_int_id == lora_request.lora_int_id + assert output.lora_request.lora_path == lora_request.lora_path + else: + assert output.lora_request is None generated_texts.append(generated_text) print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") return generated_texts diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index eb026c2ec0209..bec12eeeb48d5 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -3,7 +3,7 @@ from collections import OrderedDict from typing import NamedTuple -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from huggingface_hub.utils import HfHubHTTPError @@ -194,5 +194,8 @@ def test_get_adapter_absolute_path_huggingface_error( # Hugging Face model identifier with download error path = "org/repo" mock_exist.return_value = False - mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info") + mock_snapshot_download.side_effect = HfHubHTTPError( + "failed to query model info", + response=MagicMock(), + ) assert get_adapter_absolute_path(path) == path diff --git a/tests/models/fixtures/audioflamingo3/expected_results_batched.json b/tests/models/fixtures/audioflamingo3/expected_results_batched.json new file mode 100644 index 0000000000000..4dbb107edccb7 --- /dev/null +++ b/tests/models/fixtures/audioflamingo3/expected_results_batched.json @@ -0,0 +1 @@ +{"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other.", "(B) To indicate that language cannot express clearly, satirizing the inversion of black and white in the world"], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645], [5349, 8, 2014, 13216, 429, 4128, 4157, 3158, 9355, 11, 7578, 404, 4849, 279, 46488, 315, 3691, 323, 4158, 304, 279, 1879, 151645, 151671]]} \ No newline at end of file diff --git a/tests/models/fixtures/audioflamingo3/expected_results_single.json b/tests/models/fixtures/audioflamingo3/expected_results_single.json new file mode 100644 index 0000000000000..be9233467a20e --- /dev/null +++ b/tests/models/fixtures/audioflamingo3/expected_results_single.json @@ -0,0 +1 @@ +{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]} \ No newline at end of file diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index e2d6271e2faed..0ef4ba2577724 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -5,12 +5,12 @@ import json import pytest -from 
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( +from vllm.sampling_params import SamplingParams +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.mistral_tool_parser import ( MistralToolCall, MistralToolParser, ) -from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer from ...utils import check_logprobs_close diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py index 2dfc0072126bc..64d42432c74b9 100644 --- a/tests/models/language/pooling/test_token_classification.py +++ b/tests/models/language/pooling/test_token_classification.py @@ -68,3 +68,34 @@ def test_modernbert_models( hf_output = torch.tensor(hf_output).cpu().float() vllm_output = torch.tensor(vllm_output).cpu().float() assert torch.allclose(hf_output, vllm_output, atol=1e-2) + + +@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"]) +@pytest.mark.parametrize("dtype", ["float"]) +@torch.inference_mode +def test_auto_conversion( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.token_classify(example_prompts) + + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForTokenClassification + ) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + hf_outputs.append(softmax(output.logits[0])) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output).cpu().float() + vllm_output = torch.tensor(vllm_output).cpu().float() + assert torch.allclose(hf_output, vllm_output, atol=1e-2) diff --git a/tests/models/multimodal/generation/test_audioflamingo3.py b/tests/models/multimodal/generation/test_audioflamingo3.py new file mode 100644 index 0000000000000..d14291a62c346 --- /dev/null +++ b/tests/models/multimodal/generation/test_audioflamingo3.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os + +import pytest + +from tests.models.registry import HF_EXAMPLE_MODELS +from vllm import LLM, SamplingParams + +MODEL_NAME = "nvidia/audio-flamingo-3-hf" + + +def get_fixture_path(filename): + return os.path.join( + os.path.dirname(__file__), "../../fixtures/audioflamingo3", filename + ) + + +@pytest.fixture(scope="module") +def llm(): + # Check if the model is supported by the current transformers version + model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") + model_info.check_transformers_version(on_fail="skip") + + try: + llm = LLM( + model=MODEL_NAME, + trust_remote_code=True, + dtype="bfloat16", + enforce_eager=True, + limit_mm_per_prompt={"audio": 1}, + ) + return llm + except Exception as e: + pytest.skip(f"Failed to load model {MODEL_NAME}: {e}") + + +def test_single_generation(llm): + fixture_path = get_fixture_path("expected_results_single.json") + if not os.path.exists(fixture_path): + pytest.skip(f"Fixture not found: {fixture_path}") + + with open(fixture_path) as f: + expected = json.load(f) + + audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav" + + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "Transcribe the input speech."}, + ], + } + ] + + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + outputs = llm.chat( + messages=messages, + sampling_params=sampling_params, + ) + generated_text = outputs[0].outputs[0].text.strip() + + expected_text = expected["transcriptions"][0] + + assert expected_text in generated_text or generated_text in expected_text + + +def test_batched_generation(llm): + fixture_path = get_fixture_path("expected_results_batched.json") + if not os.path.exists(fixture_path): + pytest.skip(f"Fixture not found: {fixture_path}") + + with open(fixture_path) as f: + expected = json.load(f) + + items = [ + { + "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav", + "question": "What is surprising about the relationship " + "between the barking and the music?", + "expected_idx": 0, + }, + { + "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav", + "question": ( + "Why is the philosopher's name mentioned in the lyrics? 
" + "(A) To express a sense of nostalgia " + "(B) To indicate that language cannot express clearly, " + "satirizing the inversion of black and white in the world " + "(C) To add depth and complexity to the lyrics " + "(D) To showcase the wisdom and influence of the philosopher" + ), + "expected_idx": 1, + }, + ] + + conversations = [] + for item in items: + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": item["audio_url"]}}, + {"type": "text", "text": item["question"]}, + ], + } + ] + conversations.append(messages) + + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + outputs = llm.chat( + messages=conversations, + sampling_params=sampling_params, + ) + + for i, output in enumerate(outputs): + generated_text = output.outputs[0].text.strip() + expected_text = expected["transcriptions"][i] + + assert expected_text in generated_text or generated_text in expected_text diff --git a/tests/models/multimodal/generation/test_vit_backend_functionality.py b/tests/models/multimodal/generation/test_vit_backend_functionality.py new file mode 100644 index 0000000000000..a4e4ce312ddd4 --- /dev/null +++ b/tests/models/multimodal/generation/test_vit_backend_functionality.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Consolidated test for ViT attention backend functionality across multiple models. + +This test validates that each multimodal model can successfully generate outputs +using different ViT attention backends. Tests are parametrized by model and backend. +""" + +from dataclasses import asdict +from typing import Any + +import pytest +from transformers import AutoProcessor + +from vllm import LLM, EngineArgs, SamplingParams +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.multimodal.utils import encode_image_base64 +from vllm.multimodal.video import sample_frames_from_video +from vllm.platforms import current_platform + +from ....utils import create_new_process_for_each_test +from ...utils import dummy_hf_overrides + +# Dots.OCR prompt from official repository +# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3 +# ruff: noqa: E501 +DOTS_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox. + +1. Bbox format: [x1, y1, x2, y2] + +2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. + +3. Text Extraction & Formatting Rules: + - Picture: For the 'Picture' category, the text field should be omitted. + - Formula: Format its text as LaTeX. + - Table: Format its text as HTML. + - All Others (Text, Title, etc.): Format their text as Markdown. + +4. Constraints: + - The output text must be the original text from the image, with no translation. + - All layout elements must be sorted according to human reading order. + +5. Final Output: The entire output must be a single JSON object. 
+""" + +VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" + + +# Model configurations +MODEL_CONFIGS: dict[str, dict[str, Any]] = { + "dots_ocr": { + "model_name": "rednote-hilab/dots.ocr", + "interface": "llm_chat", + "max_model_len": 32768, + "max_num_seqs": 1, + "limit_mm_per_prompt": {"image": 1}, + "sampling_params": { + "temperature": 0.1, + "max_tokens": 16384, + "top_p": 0.9, + "stop_token_ids": None, + }, + "use_specific_image": "stop_sign", + "prompt_builder": "build_dots_ocr_prompt", + "output_validator": lambda x: len(x) > 10 and "stop" in x.lower(), + }, + "ernie45_vl": { + "model_name": "baidu/ERNIE-4.5-VL-28B-A3B-PT", + "interface": "llm_generate", + "max_model_len": 16384, + "max_num_seqs": 2, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "glm4_1v": { + "model_name": "zai-org/GLM-4.1V-9B-Thinking", + "interface": "llm_generate", + "max_model_len": 32768, + "max_num_seqs": 2, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "keye_vl": { + "model_name": "Kwai-Keye/Keye-VL-8B-Preview", + "interface": "llm_generate", + "max_model_len": 8192, + "max_num_seqs": 5, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "supported_backends": { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "ovis2_5": { + "model_name": "AIDC-AI/Ovis2.5-2B", + "interface": "llm_generate", + "max_model_len": 8192, + "max_num_seqs": 2, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "prompt_builder": "build_ovis_prompt", + "question": "What is the content of each image?", + }, + "qwen2_5_vl": { + "model_name": "Qwen/Qwen2.5-VL-3B-Instruct", + "interface": "vllm_runner", + "media_type": "video", + "max_model_len": 4000, + "max_num_seqs": 1, + "limit_mm_per_prompt": {"video": 1}, + "sampling_params": { + "max_tokens": 128, + }, + "runner_kwargs": { + "runner": "generate", + "dtype": "bfloat16", + }, + "video_params": { + "num_frames": 16, + "pruning_rates": [0.0, 0.75], + }, + }, + "qwen2_5_omni": { + "model_name": "Qwen/Qwen2.5-Omni-3B", + "interface": "llm_generate", + "max_model_len": 32768, + "max_num_seqs": 2, + "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3}, + "sampling_params": { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_tokens": 16384, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "qwen3_omni": { + "model_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "interface": "llm_generate", + "max_model_len": 32768, + "max_num_seqs": 2, + "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3}, + "sampling_params": { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_tokens": 16384, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, +} + + +# Prompt builder functions +def build_dots_ocr_prompt(images, config): + """Build Dots.OCR specific prompt with OCR instructions.""" + # Use only stop_sign image for Dots.OCR + image = images[0] # Already filtered to stop_sign + + image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}" + + placeholders = [{"type": "image_url", "image_url": {"url": 
image_url}}] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + { + "type": "text", + "text": f"<|img|><|imgpad|><|endofimg|>{DOTS_OCR_PROMPT}", + }, + ], + }, + ] + + return messages + + +def build_processor_prompt(images, config): + """Build prompt using AutoProcessor.apply_chat_template().""" + processor = AutoProcessor.from_pretrained( + config["model_name"], trust_remote_code=True + ) + + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images + ] + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": config["question"]}, + ], + }, + ] + + return processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + +def build_ovis_prompt(images, config): + """Build Ovis2.5 specific prompt with custom format.""" + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images + ] + + placeholders = "\n".join( + f"Image-{i}: \n" for i, _ in enumerate(image_urls, start=1) + ) + + return ( + f"<|im_start|>user\n\n{placeholders}\n{config['question']}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + +def build_qwen2_5_video_prompt(): + """Build Qwen2.5-VL video prompt with EVS placeholder.""" + return ( + f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n{VIDEO_PLACEHOLDER}" + "Describe this video with a short sentence (no more than 20 words)" + "<|im_end|><|im_start|>assistant\n" + ) + + +# Handler functions +def run_llm_generate_test(config, mm_encoder_attn_backend, image_assets): + """Standard LLM.generate() interface handler.""" + images = [asset.pil_image for asset in image_assets] + + # Build prompt + if config.get("use_processor"): + prompt = build_processor_prompt(images, config) + else: + prompt_builder_name = config.get("prompt_builder", "build_ovis_prompt") + prompt_builder = globals()[prompt_builder_name] + prompt = prompt_builder(images, config) + + # Determine limit_mm_per_prompt + limit_mm_per_prompt = config.get("limit_mm_per_prompt", {"image": len(images)}) + + # Create engine + engine_args = EngineArgs( + model=config["model_name"], + trust_remote_code=True, + max_model_len=config["max_model_len"], + max_num_seqs=config["max_num_seqs"], + limit_mm_per_prompt=limit_mm_per_prompt, + mm_encoder_attn_backend=mm_encoder_attn_backend, + hf_overrides=dummy_hf_overrides, + load_format="dummy", + ) + + engine_dict = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_dict) + + # Generate + sampling_params = SamplingParams(**config["sampling_params"]) + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": {"image": images}, + }, + sampling_params=sampling_params, + ) + + # Validate + for o in outputs: + generated_text = o.outputs[0].text + validator = config.get("output_validator", lambda x: len(x) > 10) + assert validator(generated_text), ( + f"Validation failed for {config['model_name']}: {generated_text}" + ) + + +def run_llm_chat_test(config, mm_encoder_attn_backend, image_assets): + """LLM.chat() interface handler for Dots.OCR.""" + # Filter to stop_sign image only + stop_sign_image = [ + asset.pil_image for asset in image_assets if asset.name == "stop_sign" + ][0] + + # Build messages + messages = build_dots_ocr_prompt([stop_sign_image], config) + + # Create engine + engine_args = EngineArgs( + model=config["model_name"], + trust_remote_code=True, + max_model_len=config["max_model_len"], + 
max_num_seqs=config["max_num_seqs"], + limit_mm_per_prompt=config["limit_mm_per_prompt"], + mm_encoder_attn_backend=mm_encoder_attn_backend, + hf_overrides=dummy_hf_overrides, + load_format="dummy", + ) + + engine_dict = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_dict) + + # Generate using chat + sampling_params = SamplingParams(**config["sampling_params"]) + outputs = llm.chat(messages=messages, sampling_params=sampling_params) + + # Validate + for o in outputs: + generated_text = o.outputs[0].text + validator = config.get("output_validator", lambda x: len(x) > 10) + assert validator(generated_text), ( + f"Validation failed for {config['model_name']}: {generated_text}" + ) + + +def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner): + """Video test with EVS (Efficient Video Sampling) handler.""" + for pruning_rate in config["video_params"]["pruning_rates"]: + num_frames = config["video_params"]["num_frames"] + + # Sample frames from video + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + # Build prompt and prepare video + prompt = build_qwen2_5_video_prompt() + prompts = [prompt] + videos = [sampled_vids[0]] + + # Run with vllm_runner context manager + with vllm_runner( + config["model_name"], + max_model_len=config["max_model_len"], + max_num_seqs=config["max_num_seqs"], + limit_mm_per_prompt=config["limit_mm_per_prompt"], + tensor_parallel_size=1, + video_pruning_rate=pruning_rate, + mm_encoder_attn_backend=mm_encoder_attn_backend, + hf_overrides=dummy_hf_overrides, + load_format="dummy", + **config["runner_kwargs"], + ) as vllm_model: + outputs = vllm_model.generate_greedy( + prompts, + config["sampling_params"]["max_tokens"], + videos=videos, + ) + + # Validate output + assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}" + output_ids, output_text = outputs[0] + assert len(output_ids) > 0, "Generated no output IDs" + assert len(output_text) > 0, "Generated empty text" + assert isinstance(output_text, str), ( + f"Output is not string: {type(output_text)}" + ) + + +# Main test function +@pytest.mark.parametrize("model_key", list(MODEL_CONFIGS.keys())) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +@pytest.mark.skip(reason="Broken test due to memory segmentation fault") +@create_new_process_for_each_test() +def test_vit_backend_functionality( + model_key: str, + mm_encoder_attn_backend: AttentionBackendEnum | None, + image_assets, + video_assets, + vllm_runner, + request, +): + """Test ViT attention backend functionality for multimodal models. + + This test validates that each model can successfully generate outputs + using different ViT attention backends. The test: + 1. Filters unsupported backends per model + 2. Applies appropriate GPU marks + 3. Routes to the correct test handler based on interface + 4. Validates output meets minimum requirements + """ + config = MODEL_CONFIGS[model_key] + + # Step 1: Backend filtering + if ( + "supported_backends" in config + and mm_encoder_attn_backend is not None + and mm_encoder_attn_backend not in config["supported_backends"] + ): + pytest.skip( + f"{model_key} does not support {mm_encoder_attn_backend} backend now." 
+ ) + + # Step 2: Apply GPU marks dynamically + if "gpu_marks" in config: + for mark in config["gpu_marks"]: + request.applymarker(mark) + + # Step 3: Route to appropriate handler + if config.get("media_type") == "video": + run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner) + elif config["interface"] == "llm_chat": + run_llm_chat_test(config, mm_encoder_attn_backend, image_assets) + elif config["interface"] == "llm_generate": + run_llm_generate_test(config, mm_encoder_attn_backend, image_assets) + else: + raise ValueError(f"Unknown interface: {config['interface']}") diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index 9e9087cb0fc4d..0eaef49e2395c 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -9,7 +9,7 @@ from mistral_common.audio import Audio from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk from mistral_common.protocol.instruct.messages import UserMessage -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from ....conftest import AudioTestAssets from ....utils import RemoteOpenAIServer diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 592862c2a0bb0..b206995a9cecc 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -1,150 +1,146 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +import librosa import pytest +from transformers import AutoModelForSpeechSeq2Seq -from vllm import SamplingParams from vllm.assets.audio import AudioAsset +from vllm.platforms import current_platform -from ....conftest import VllmRunner +from ....conftest import HfRunner, PromptAudioInput, VllmRunner from ....utils import create_new_process_for_each_test, multi_gpu_test +from ...registry import HF_EXAMPLE_MODELS +from ...utils import check_logprobs_close -PROMPTS = [ - { - "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", - "multi_modal_data": { - "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, - }, - }, - { # Test explicit encoder/decoder prompt - "encoder_prompt": { - "prompt": "", - "multi_modal_data": { - "audio": AudioAsset("winning_call").audio_and_sample_rate, - }, - }, - "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", - }, -] +VLLM_PROMPT = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" +HF_PROMPT = "" +# Whisper expects 16kHz audio +WHISPER_SAMPLE_RATE = 16000 -EXPECTED = { - "openai/whisper-tiny": [ - " He has birth words I spoke in the original corner of that. And a" - " little piece of black coat poetry. Mary had a little sandwich," - " sweet, with white and snow. And everyone had it very went the last" - " would sure to go.", - " >> And the old one, fit John the way to Edgar Martinez. >> One more" - " to line down the field line for our base camp. Here comes joy. Here" - " is June and the third base. They're going to wave him in. The throw" - " to the plate will be late. The Mariners are going to play for the" - " American League Championship. I don't believe it. It just continues" - " by all five.", - ], - "openai/whisper-small": [ - " The first words I spoke in the original pornograph. 
A little piece" - " of practical poetry. Mary had a little lamb, its fleece was quite a" - " slow, and everywhere that Mary went the lamb was sure to go.", - " And the old one pitch on the way to Edgar Martinez one month. Here" - " comes joy. Here is Junior to third base. They're gonna wave him" - " in. The throw to the plate will be late. The Mariners are going to" - " play for the American League Championship. I don't believe it. It" - " just continues. My, oh my.", - ], - "openai/whisper-medium": [ - " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. Mary had a little lamb, its fleece was quite as" - " slow, and everywhere that Mary went the lamb was sure to go.", - " And the 0-1 pitch on the way to Edgar Martinez swung on the line" - " down the left field line for Obeyshev. Here comes Joy. Here is" - " Jorgen at third base. They're going to wave him in. The throw to the" - " plate will be late. The Mariners are going to play for the American" - " League Championship. I don't believe it. It just continues. My, oh" - " my.", - ], - "openai/whisper-large-v3": [ - " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. Mary had a little lamb, its feet were quite as" - " slow, and everywhere that Mary went, the lamb was sure to go.", - " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line." - " Now the left field line for a base hit. Here comes Joy. Here is" - " Junior to third base. They're going to wave him in. The throw to the" - " plate will be late. The Mariners are going to play for the American" - " League Championship. I don't believe it. It just continues. My, oh," - " my.", - ], - "openai/whisper-large-v3-turbo": [ - " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. Mary had a little lamb, its streets were quite" - " as slow, and everywhere that Mary went the lamb was sure to go.", - " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line" - " down the left field line for a base hit. Here comes Joy. Here is" - " Junior to third base. They're going to wave him in. The throw to the" - " plate will be late. The Mariners are going to play for the American" - " League Championship. I don't believe it. It just continues. My, oh," - " my.", - ], -} + +@pytest.fixture(autouse=True) +def use_spawn_for_whisper(monkeypatch): + """Whisper has issues with forked workers, use spawn instead.""" + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") def run_test( + hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], + inputs: Sequence[tuple[list[str], list[str], PromptAudioInput]], model: str, *, + max_model_len: int, + dtype: str, + max_tokens: int, + num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: str | None = None, - dtype: str = "half", + enforce_eager: bool = True, ) -> None: - prompt_list = PROMPTS * 10 - expected_list = EXPECTED[model] * 10 + """Inference result should be the same between hf and vllm. + All the audio fixtures for the test are from AudioAsset. + For huggingface runner, we provide the audio as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. 
+ """ with vllm_runner( model, dtype=dtype, - max_model_len=448, + max_model_len=max_model_len, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - # TODO (NickLucche) figure out output differences with non-eager and re-enable - enforce_eager=True, + limit_mm_per_prompt={"audio": 2}, + enforce_eager=enforce_eager, + disable_custom_all_reduce=True, ) as vllm_model: - llm = vllm_model.llm + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs( + vllm_prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + ) + for vllm_prompts, _, audios in inputs + ] - sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - max_tokens=200, + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit( + hf_prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + ) + for _, hf_prompts, audios in inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", ) - outputs = llm.generate(prompt_list, sampling_params) - for output, expected in zip(outputs, expected_list): - print(output.outputs[0].text) - assert output.outputs[0].text == expected +@pytest.fixture +def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]: + audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] + inputs = [] + for asset in audio_assets: + audio, orig_sr = asset.audio_and_sample_rate + # Resample to Whisper's expected sample rate (16kHz) + if orig_sr != WHISPER_SAMPLE_RATE: + audio = librosa.resample( + audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE + ) + # vLLM prompts, HF prompts, audio inputs + inputs.append(([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)])) + return inputs + + +def check_model_available(model: str) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") @pytest.mark.core_model -@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) -@pytest.mark.parametrize("dtype", ["half"]) -@create_new_process_for_each_test() -def test_models(vllm_runner, model, dtype) -> None: - run_test( - vllm_runner, - model, - tensor_parallel_size=1, - dtype=dtype, - ) - - @pytest.mark.cpu_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @pytest.mark.parametrize("dtype", ["half"]) -def test_models_cpu(vllm_runner, model, dtype) -> None: - # @create_new_process_for_each_test() does not work for some runners - # TODO: to fix cpu privilege issues in run-cpu-test-arm.sh +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@create_new_process_for_each_test("spawn") +def test_models( + hf_runner, + vllm_runner, + model: str, + dtype: str, + num_logprobs: int, + input_audios, + enforce_eager: bool, +) -> None: + check_model_available(model) + if current_platform.is_cpu() and not enforce_eager: + pytest.skip("Skipping test for CPU with non-eager mode") run_test( + hf_runner, vllm_runner, + input_audios, model, - tensor_parallel_size=1, dtype=dtype, + max_model_len=448, + max_tokens=200, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + enforce_eager=enforce_eager, ) @@ -152,15 +148,31 @@ def test_models_cpu(vllm_runner, model, dtype) -> None: 
@pytest.mark.core_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@create_new_process_for_each_test() +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [200]) +@pytest.mark.parametrize("num_logprobs", [5]) +@create_new_process_for_each_test("spawn") def test_models_distributed( + hf_runner, vllm_runner, - model, - distributed_executor_backend, + model: str, + distributed_executor_backend: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + input_audios, ) -> None: + check_model_available(model) run_test( + hf_runner, vllm_runner, + input_audios, model, + dtype=dtype, + max_model_len=448, + max_tokens=max_tokens, + num_logprobs=num_logprobs, tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, + enforce_eager=False, ) diff --git a/tests/models/multimodal/processing/test_audioflamingo3.py b/tests/models/multimodal/processing/test_audioflamingo3.py new file mode 100644 index 0000000000000..d7c00516ffead --- /dev/null +++ b/tests/models/multimodal/processing/test_audioflamingo3.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +from transformers import PretrainedConfig + +from tests.models.registry import HF_EXAMPLE_MODELS + + +class MockAudioFlamingo3Config(PretrainedConfig): + model_type = "audioflamingo3" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.audio_config = PretrainedConfig() + self.text_config = PretrainedConfig() + + +class MockAudioFlamingo3Processor: + def __init__(self): + self.audio_token = "" + self.audio_token_id = 12345 + self.feature_extractor = MockFeatureExtractor() + + def __call__(self, text=None, audios=None, **kwargs): + return {"input_ids": [1, 2, 3], "input_features": [np.zeros((3000, 80))]} + + +class MockFeatureExtractor: + def __init__(self): + self.sampling_rate = 16000 + self.chunk_length = 30 + + +@pytest.fixture +def mock_ctx(): + config = MockAudioFlamingo3Config() + + ctx = MagicMock() + ctx.get_hf_config.return_value = config + ctx.get_hf_processor.return_value = MockAudioFlamingo3Processor() + ctx.model_config.hf_config = config + return ctx + + +@pytest.fixture(autouse=True) +def check_transformers_version(): + # Check if the model is supported by the current transformers version + model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") + model_info.check_transformers_version(on_fail="skip") + + +def test_audio_chunk_counting(mock_ctx): + from vllm.model_executor.models.audioflamingo3 import ( + AudioFlamingo3DummyInputsBuilder, + AudioFlamingo3MultiModalProcessor, + AudioFlamingo3ProcessingInfo, + ) + + info = AudioFlamingo3ProcessingInfo(mock_ctx) + processor = AudioFlamingo3MultiModalProcessor( + info, AudioFlamingo3DummyInputsBuilder(info) + ) + + sr = 16000 + audio_1 = np.zeros(30 * sr) + audio_2 = np.zeros(45 * sr) + + mm_data = {"audio": [audio_1, audio_2]} + prompt = "<|user|>Listen.<|end|>" + + from vllm.multimodal.processing import BaseMultiModalProcessor + + def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs): + return {"input_ids": [1, 2, 3], "input_features": torch.randn(1, 80, 3000)} + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call) + + processed = processor._call_hf_processor(prompt, mm_data, {}, {}) + + chunk_counts = processed["chunk_counts"] + + assert chunk_counts[0].item() == 1 + assert chunk_counts[1].item() == 2 + assert len(chunk_counts) == 2 + + +def test_dummy_data_generation(mock_ctx): + from vllm.model_executor.models.audioflamingo3 import ( + AudioFlamingo3DummyInputsBuilder, + AudioFlamingo3ProcessingInfo, + ) + + info = AudioFlamingo3ProcessingInfo(mock_ctx) + builder = AudioFlamingo3DummyInputsBuilder(info) + + mm_counts = {"audio": 2} + dummy_data = builder.get_dummy_mm_data(100, mm_counts, None) + + assert "audio" in dummy_data + assert len(dummy_data["audio"]) == 2 + + expected_len = 600 * 16000 + assert len(dummy_data["audio"][0]) == expected_len diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 2e032ac4ca526..67861ebfc44e4 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -22,11 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal from vllm.multimodal.processing import BaseMultiModalProcessor, 
InputProcessingContext -from vllm.tokenizers import ( - MistralTokenizer, - TokenizerLike, - cached_tokenizer_from_config, -) +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from ....multimodal.utils import random_audio, random_image, random_video from ...registry import ( diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 5d489549c5b46..cb875436857cf 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -8,6 +8,7 @@ from typing import Any, TypeAlias import numpy as np import pytest +import torch import torch.nn as nn from PIL import Image @@ -35,6 +36,7 @@ from vllm.tokenizers import cached_tokenizer_from_config from vllm.utils.collection_utils import is_list_of from vllm.utils.torch_utils import set_default_torch_dtype +from ....utils import create_new_process_for_each_test from ...registry import HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides from .test_common import get_model_ids_to_test, get_text_token_prompts @@ -136,6 +138,7 @@ def create_batched_mm_kwargs( ) +# TODO(Isotr0py): Don't initalize model during test @contextmanager def initialize_dummy_model( model_cls: type[nn.Module], @@ -150,16 +153,21 @@ def initialize_dummy_model( backend="nccl", ) initialize_model_parallel(tensor_model_parallel_size=1) + + current_device = torch.get_default_device() vllm_config = VllmConfig(model_config=model_config) with set_current_vllm_config(vllm_config=vllm_config): with set_default_torch_dtype(model_config.dtype): + torch.set_default_device(current_platform.device_type) model = model_cls(vllm_config=vllm_config) + torch.set_default_device(current_device) yield model del model cleanup_dist_env_and_memory() +@create_new_process_for_each_test() @pytest.mark.parametrize("model_id", get_model_ids_to_test()) def test_model_tensor_schema(model_id: str): model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) diff --git a/tests/models/registry.py b/tests/models/registry.py index 18056a9657e82..c5d72b5d581b9 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -356,7 +356,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MistralLarge3ForCausalLM": _HfExamplesInfo( - "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False + "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4" ), "MixtralForCausalLM": _HfExamplesInfo( "mistralai/Mixtral-8x7B-Instruct-v0.1", @@ -573,12 +573,17 @@ _AUTOMATIC_CONVERTED_MODELS = { "Qwen3ForSequenceClassification": _HfExamplesInfo( "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" ), + "Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"), } _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), + "AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo( + "nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev" + ), "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"), + "BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"), "BeeForConditionalGeneration": _HfExamplesInfo( "Open-Bee/Bee-8B-RL", trust_remote_code=True, @@ -635,7 +640,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ), "HunYuanVLForConditionalGeneration": _HfExamplesInfo( "tencent/HunyuanOCR", - is_available_online=False, + 
hf_overrides={"num_experts": 0}, ), "Idefics3ForConditionalGeneration": _HfExamplesInfo( "HuggingFaceM4/Idefics3-8B-Llama3", @@ -674,8 +679,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31", ), "LightOnOCRForConditionalGeneration": _HfExamplesInfo( - "lightonai/LightOnOCR-1B", - is_available_online=False, + "lightonai/LightOnOCR-1B-1025" ), "Llama4ForConditionalGeneration": _HfExamplesInfo( "meta-llama/Llama-4-Scout-17B-16E-Instruct", @@ -779,8 +783,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { "ministral-3": "mistralai/Ministral-3-3B-Instruct-2512", }, tokenizer_mode="mistral", - # TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available. - is_available_online=False, ), "QwenVLForConditionalGeneration": _HfExamplesInfo( "Qwen/Qwen-VL", @@ -843,7 +845,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] - "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), + "WhisperForConditionalGeneration": _HfExamplesInfo( + "openai/whisper-large-v3-turbo", + extras={"v3": "openai/whisper-large-v3"}, + ), # [Cross-encoder] "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), } @@ -886,6 +891,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "EagleMistralLarge3ForCausalLM": _HfExamplesInfo( "mistralai/Mistral-Large-3-675B-Instruct-2512", speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle", + # TODO: revert once figuring out OOM in CI is_available_online=False, ), "LlamaForCausalLMEagle3": _HfExamplesInfo( diff --git a/tests/multimodal/test_sparse_tensor_validation_unit.py b/tests/multimodal/test_sparse_tensor_validation_unit.py new file mode 100644 index 0000000000000..2eec8ea8283a2 --- /dev/null +++ b/tests/multimodal/test_sparse_tensor_validation_unit.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for sparse tensor validation. + +Simple, fast unit tests that can run without server fixtures. 
+Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v +""" + +import io + +import pytest +import torch + + +class TestSparseTensorValidationContextManager: + """Test that torch.sparse.check_sparse_tensor_invariants() works as expected.""" + + def test_valid_sparse_tensor_passes(self): + """Valid sparse tensors should pass validation.""" + indices = torch.tensor([[0, 1], [0, 1]]) + values = torch.tensor([1.0, 2.0]) + shape = (2, 2) + + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + dense = tensor.to_dense() + + assert dense.shape == shape + + def test_out_of_bounds_indices_rejected(self): + """Sparse tensors with out-of-bounds indices should be rejected.""" + indices = torch.tensor([[5], [5]]) # Out of bounds for 2x2 + values = torch.tensor([1.0]) + shape = (2, 2) + + with pytest.raises(RuntimeError) as exc_info: # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + tensor.to_dense() + + assert ( + "index" in str(exc_info.value).lower() + or "bound" in str(exc_info.value).lower() + ) + + def test_negative_indices_rejected(self): + """Sparse tensors with negative indices should be rejected.""" + indices = torch.tensor([[-1], [0]]) + values = torch.tensor([1.0]) + shape = (2, 2) + + with pytest.raises(RuntimeError): # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + tensor.to_dense() + + def test_without_context_manager_allows_invalid(self): + """ + WITHOUT validation, invalid tensors may not immediately error. + + This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate + by default, which can lead to memory corruption. 
+ """ + indices = torch.tensor([[100], [100]]) # Way out of bounds + values = torch.tensor([1.0]) + shape = (2, 2) + + # Without validation context, this might create an invalid tensor + # (actual behavior depends on PyTorch version) + tensor = torch.sparse_coo_tensor(indices, values, shape) + + # The tensor object is created, but it's invalid + assert tensor.is_sparse + + +class TestTorchLoadWithValidation: + """Test torch.load() with sparse tensor validation.""" + + def test_load_valid_sparse_tensor_with_validation(self): + """Valid sparse tensors should load successfully with validation.""" + # Create and save a valid sparse tensor + indices = torch.tensor([[0, 1], [0, 1]]) + values = torch.tensor([1.0, 2.0]) + tensor = torch.sparse_coo_tensor(indices, values, (2, 2)) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + dense = loaded.to_dense() + + assert dense.shape == (2, 2) + + def test_load_invalid_sparse_tensor_rejected(self): + """Invalid sparse tensors should be caught when loaded with validation.""" + # Create an invalid sparse tensor (out of bounds) + indices = torch.tensor([[10], [10]]) + values = torch.tensor([1.0]) + tensor = torch.sparse_coo_tensor(indices, values, (2, 2)) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation - should fail on to_dense() + with pytest.raises(RuntimeError): # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + loaded.to_dense() + + def test_load_dense_tensor_unaffected(self): + """Dense tensors should work normally with the validation context.""" + # Create and save a dense tensor + tensor = torch.randn(10, 20) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation (should have no effect on dense tensors) + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + + assert loaded.shape == (10, 20) + assert not loaded.is_sparse + + +if __name__ == "__main__": + # Allow running directly for quick testing + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py index 8dd4551ff4b96..a43d2abfdd8b8 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -10,9 +10,9 @@ import pytest from tests.utils import RemoteOpenAIServer from vllm.platforms import current_platform -if not current_platform.is_device_capability(100): +if not current_platform.is_device_capability_family(100): pytest.skip( - "This test only runs on Blackwell GPUs (SM100).", allow_module_level=True + "This test only runs on Blackwell GPUs (SM10x).", allow_module_level=True ) diff --git a/tests/reasoning/test_mistral_reasoning_parser.py b/tests/reasoning/test_mistral_reasoning_parser.py index 01592fd0782a9..d6da723f80b08 100644 --- a/tests/reasoning/test_mistral_reasoning_parser.py +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -5,7 +5,7 @@ import pytest from tests.reasoning.utils import run_reasoning_extraction_mistral from vllm.reasoning import ReasoningParser, ReasoningParserManager -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer parser_name = "mistral" diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py index 
695312a0cadfe..a020fb8e97161 100644 --- a/tests/reasoning/utils.py +++ b/tests/reasoning/utils.py @@ -4,7 +4,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.reasoning import ReasoningParser -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer class StreamingReasoningReconstructor: diff --git a/tests/standalone_tests/lazy_imports.py b/tests/standalone_tests/lazy_imports.py index ddcdd2a51ab9f..fff5c54f276d3 100644 --- a/tests/standalone_tests/lazy_imports.py +++ b/tests/standalone_tests/lazy_imports.py @@ -5,9 +5,6 @@ # The utility function cannot be placed in `vllm.utils` # this needs to be a standalone script import sys -from contextlib import nullcontext - -from vllm_test_utils import BlameResult, blame # List of modules that should not be imported too early. # Lazy import `torch._inductor.async_compile` to avoid creating @@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame # `cv2` can easily mess up the environment. module_names = ["torch._inductor.async_compile", "cv2"] +# set all modules in `module_names` to be None. +# if we import any modules during `import vllm`, there would be a +# hard error and nice stacktrace on the first import. +for module_name in module_names: + sys.modules[module_name] = None # type: ignore[assignment] -def any_module_imported(): - return any(module_name in sys.modules for module_name in module_names) - - -# In CI, we only check finally if the module is imported. -# If it is indeed imported, we can rerun the test with `use_blame=True`, -# which will trace every function call to find the first import location, -# and help find the root cause. -# We don't run it in CI by default because it is slow. -use_blame = False -context = blame(any_module_imported) if use_blame else nullcontext() -with context as result: - import vllm # noqa - -if use_blame: - assert isinstance(result, BlameResult) - print(f"the first import location is:\n{result.trace_stack}") - -assert not any_module_imported(), ( - f"Some the modules in {module_names} are imported. To see the first" - f" import location, run the test with `use_blame=True`." -) +import vllm # noqa diff --git a/tests/test_inputs.py b/tests/test_inputs.py index c4339827de8b6..073be24a4a072 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -7,7 +7,7 @@ from vllm.config import ModelConfig from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.preprocess import InputPreprocessor -from vllm.tokenizers import init_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config pytestmark = pytest.mark.cpu_test @@ -34,6 +34,13 @@ INPUTS_SLICES = [ ] +# Test that a nested mixed-type list of lists raises a TypeError. 
+@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]]) +def test_invalid_input_raise_type_error(invalid_input): + with pytest.raises(TypeError): + parse_raw_prompts(invalid_input) + + def test_parse_raw_single_batch_empty(): with pytest.raises(ValueError, match="at least one prompt"): parse_raw_prompts([]) @@ -108,7 +115,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): ) def test_preprocessor_always_mm_code_path(model_id, prompt): model_config = ModelConfig(model=model_id) - tokenizer = init_tokenizer_from_config(model_config) + tokenizer = cached_tokenizer_from_config(model_config) input_preprocessor = InputPreprocessor(model_config, tokenizer) # HF processor adds sep token diff --git a/tests/tokenizers_/test_basic.py b/tests/tokenizers_/test_basic.py index b152227a5a50f..0510261eacde7 100644 --- a/tests/tokenizers_/test_basic.py +++ b/tests/tokenizers_/test_basic.py @@ -3,38 +3,39 @@ from typing import _get_protocol_attrs # type: ignore import pytest -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) from vllm.tokenizers import TokenizerLike, get_tokenizer +from vllm.tokenizers.mistral import MistralTokenizer def _get_missing_attrs(obj: object, target: type): return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)] +def _assert_tokenizer_like(tokenizer: object): + missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike) + assert not missing_attrs, f"Missing attrs: {missing_attrs}" + + def test_tokenizer_like_protocol(): - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer("gpt2", use_fast=False), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer("gpt2", use_fast=False) + assert isinstance(tokenizer, PreTrainedTokenizer) + _assert_tokenizer_like(tokenizer) - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer("gpt2", use_fast=True), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer("gpt2", use_fast=True) + assert isinstance(tokenizer, PreTrainedTokenizerFast) + _assert_tokenizer_like(tokenizer) - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer( - "mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral" - ), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer( + "mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral" + ) + assert isinstance(tokenizer, MistralTokenizer) + _assert_tokenizer_like(tokenizer) @pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"]) diff --git a/tests/tokenizers_/test_detokenize.py b/tests/tokenizers_/test_detokenize.py index ae1d6b0956722..d307993d04df9 100644 --- a/tests/tokenizers_/test_detokenize.py +++ b/tests/tokenizers_/test_detokenize.py @@ -8,7 +8,7 @@ import pytest from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import ( FastIncrementalDetokenizer, diff --git a/tests/tokenizers_/test_registry.py b/tests/tokenizers_/test_registry.py index 7e795350d64c8..546f38b078dde 100644 --- a/tests/tokenizers_/test_registry.py +++ b/tests/tokenizers_/test_registry.py @@ -2,7 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
from pathlib import Path -from vllm.tokenizers import TokenizerLike, TokenizerRegistry, get_tokenizer +import pytest + +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.registry import ( + TokenizerRegistry, + get_tokenizer, + resolve_tokenizer_args, +) class TestTokenizer(TokenizerLike): @@ -40,10 +47,22 @@ class TestTokenizer(TokenizerLike): return True +@pytest.mark.parametrize("runner_type", ["generate", "pooling"]) +def test_resolve_tokenizer_args_idempotent(runner_type): + tokenizer_mode, tokenizer_name, args, kwargs = resolve_tokenizer_args( + "facebook/opt-125m", + runner_type=runner_type, + ) + + assert (tokenizer_mode, tokenizer_name, args, kwargs) == resolve_tokenizer_args( + tokenizer_name, *args, **kwargs + ) + + def test_customized_tokenizer(): TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__) - tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer", "abc") + tokenizer = TokenizerRegistry.load_tokenizer("test_tokenizer", "abc") assert isinstance(tokenizer, TestTokenizer) assert tokenizer.path_or_repo_id == "abc" assert tokenizer.bos_token_id == 0 diff --git a/tests/tool_parsers/__init__.py b/tests/tool_parsers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tool_use/test_deepseekv31_tool_parser.py b/tests/tool_parsers/test_deepseekv31_tool_parser.py similarity index 96% rename from tests/tool_use/test_deepseekv31_tool_parser.py rename to tests/tool_parsers/test_deepseekv31_tool_parser.py index 8beb7739b6081..69a4cc8b989c5 100644 --- a/tests/tool_use/test_deepseekv31_tool_parser.py +++ b/tests/tool_parsers/test_deepseekv31_tool_parser.py @@ -3,10 +3,10 @@ import pytest -from vllm.entrypoints.openai.tool_parsers.deepseekv31_tool_parser import ( +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.deepseekv31_tool_parser import ( DeepSeekV31ToolParser, ) -from vllm.tokenizers import get_tokenizer MODEL = "deepseek-ai/DeepSeek-V3.1" diff --git a/tests/tool_use/test_ernie45_moe_tool_parser.py b/tests/tool_parsers/test_ernie45_moe_tool_parser.py similarity index 99% rename from tests/tool_use/test_ernie45_moe_tool_parser.py rename to tests/tool_parsers/test_ernie45_moe_tool_parser.py index 92f86de23267b..533bd1ec3dfff 100644 --- a/tests/tool_use/test_ernie45_moe_tool_parser.py +++ b/tests/tool_parsers/test_ernie45_moe_tool_parser.py @@ -13,9 +13,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally +from vllm.tool_parsers.ernie45_tool_parser import Ernie45ToolParser # Use a common model that is likely to be available MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking" diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_parsers/test_glm4_moe_tool_parser.py similarity index 99% rename from tests/tool_use/test_glm4_moe_tool_parser.py rename to tests/tool_parsers/test_glm4_moe_tool_parser.py index 753b3f1c23adf..52f5a9198e9b4 100644 --- a/tests/tool_use/test_glm4_moe_tool_parser.py +++ b/tests/tool_parsers/test_glm4_moe_tool_parser.py @@ -7,12 +7,10 @@ import json import pytest from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.glm4_moe_tool_parser import ( +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.glm4_moe_tool_parser import ( Glm4MoeModelToolParser, ) 
-from vllm.tokenizers import get_tokenizer - -pytestmark = pytest.mark.cpu_test pytest.skip("skip glm4_moe parser test", allow_module_level=True) # Use a common model that is likely to be available diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_parsers/test_jamba_tool_parser.py similarity index 98% rename from tests/tool_use/test_jamba_tool_parser.py rename to tests/tool_parsers/test_jamba_tool_parser.py index 9036bd32dd704..ccad16ae2f6b6 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_parsers/test_jamba_tool_parser.py @@ -9,11 +9,9 @@ import pytest from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.jamba_tool_parser import JambaToolParser MODEL = "ai21labs/Jamba-tiny-dev" diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_parsers/test_kimi_k2_tool_parser.py similarity index 99% rename from tests/tool_use/test_kimi_k2_tool_parser.py rename to tests/tool_parsers/test_kimi_k2_tool_parser.py index 1558a9c3e01f2..d02f53c34b455 100644 --- a/tests/tool_use/test_kimi_k2_tool_parser.py +++ b/tests/tool_parsers/test_kimi_k2_tool_parser.py @@ -7,10 +7,8 @@ import json import pytest from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser from vllm.tokenizers import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser # Use a common model that is likely to be available MODEL = "moonshotai/Kimi-K2-Instruct" diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_parsers/test_minimax_tool_parser.py similarity index 99% rename from tests/tool_use/test_minimax_tool_parser.py rename to tests/tool_parsers/test_minimax_tool_parser.py index dda63f984a832..28cfc4ea7a175 100644 --- a/tests/tool_use/test_minimax_tool_parser.py +++ b/tests/tool_parsers/test_minimax_tool_parser.py @@ -12,10 +12,8 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.minimax_tool_parser import MinimaxToolParser from vllm.tokenizers import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.minimax_tool_parser import MinimaxToolParser # Use a common model that is likely to be available MODEL = "MiniMaxAi/MiniMax-M1-40k" diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py similarity index 99% rename from tests/tool_use/test_mistral_tool_parser.py rename to tests/tool_parsers/test_mistral_tool_parser.py index 2dd0399cb8eeb..9400a67267f4c 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -12,13 +12,10 @@ from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser -from vllm.tokenizers import ( - MistralTokenizer, - TokenizerLike, - get_tokenizer, -) +from vllm.tokenizers import TokenizerLike, get_tokenizer from 
vllm.tokenizers.detokenizer_utils import detokenize_incrementally +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser @pytest.fixture(scope="module") diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_parsers/test_openai_tool_parser.py similarity index 99% rename from tests/tool_use/test_openai_tool_parser.py rename to tests/tool_parsers/test_openai_tool_parser.py index 6537f281c0e1b..44b8c92745e91 100644 --- a/tests/tool_use/test_openai_tool_parser.py +++ b/tests/tool_parsers/test_openai_tool_parser.py @@ -15,8 +15,8 @@ from openai_harmony import ( ) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.openai_tool_parser import OpenAIToolParser from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.openai_tool_parser import OpenAIToolParser MODEL = "gpt2" diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_parsers/test_qwen3coder_tool_parser.py similarity index 99% rename from tests/tool_use/test_qwen3coder_tool_parser.py rename to tests/tool_parsers/test_qwen3coder_tool_parser.py index 5a56768805fdf..3a0a612d7fbfd 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_parsers/test_qwen3coder_tool_parser.py @@ -13,14 +13,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( - Qwen3CoderToolParser, -) -from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.qwen3coder_tool_parser import ( + Qwen3CoderToolParser, +) +from vllm.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_parsers/test_seed_oss_tool_parser.py similarity index 99% rename from tests/tool_use/test_seed_oss_tool_parser.py rename to tests/tool_parsers/test_seed_oss_tool_parser.py index 8795c35a1347f..c7f595830f34b 100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_parsers/test_seed_oss_tool_parser.py @@ -14,11 +14,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.seed_oss_tool_parser import SeedOssToolParser # Use a common model that is likely to be available MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_parsers/test_xlam_tool_parser.py similarity index 99% rename from tests/tool_use/test_xlam_tool_parser.py rename to tests/tool_parsers/test_xlam_tool_parser.py index 3098fda036a81..380792a9926a4 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_parsers/test_xlam_tool_parser.py @@ -12,11 +12,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally - 
-pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.xlam_tool_parser import xLAMToolParser # Use a common model that is likely to be available MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r" diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index d5572cfbebe3c..35ed8d215f73a 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -12,7 +12,7 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionToolsParam, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools +from vllm.tool_parsers.utils import get_json_schema_from_tools pytestmark = pytest.mark.cpu_test diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index f08e2f480e30f..734819fcdca83 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches( num_tokens, batch_spec.batch_size, split_point=split_point, + num_ubatches=2, ) assert ubatch_slices is not None and len(ubatch_slices) == 2 diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py index faace3473a281..4529c2cfc29b6 100644 --- a/tests/v1/attention/test_chunked_local_attention.py +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): ) # Call the function - result = make_local_attention_virtual_batches( + result, _ = make_local_attention_virtual_batches( attn_chunk_size, common_attn_metadata, block_size ) diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py index 1c45e7fe366ff..7a58e1c9bad03 100644 --- a/tests/v1/determinism/test_batch_invariance.py +++ b/tests/v1/determinism/test_batch_invariance.py @@ -188,7 +188,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - # enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", # not everything is supported diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index 40b9d1fe850c6..bc9674ee86cf8 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -76,6 +76,8 @@ def sample_json_schema(): }, "required": ["name", "age", "skills", "grade", "email", "work_history"], "additionalProperties": False, + "minProperties": 1, + "maxProperties": 10, } @@ -96,6 +98,9 @@ def unsupported_json_schema(): }, "required": ["score", "tags"], "additionalProperties": False, + "patternProperties": { + "^score$": {"type": "integer"}, + }, } diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 53da09cfbc21d..66804fa671c7c 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -461,7 +461,7 @@ class TestNixlHandshake: metadata = NixlConnectorMetadata() if num_xfers > 0: num_xfers -= 1 - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], kv_transfer_params={ @@ -532,7 +532,7 @@ class TestNixlHandshake: vllm_config, connector.engine_id ) metadata = NixlConnectorMetadata() - metadata.add_new_req( + 
metadata.add_new_req_to_recv( request_id="id", local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -588,7 +588,7 @@ class TestNixlHandshake: metadata = NixlConnectorMetadata() total_reqs = 5 for i in range(total_reqs): - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=f"id_{i}", local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init): # Create transfer metadata request_id = "test_req_for_stats" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init): request_id = "test_handshake_fail" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init): request_id = "test_transfer_fail" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[7, 8, 9], kv_transfer_params={ diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py index a248104e16d2d..3516c0013879d 100644 --- a/tests/v1/kv_offload/test_cpu_gpu.py +++ b/tests/v1/kv_offload/test_cpu_gpu.py @@ -9,7 +9,7 @@ import torch from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec -from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler +from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers BACKENDS_TO_TEST = [FlashAttentionBackend] @@ -82,7 +82,7 @@ def test_transfer( # create handler cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size - handler = CpuGpuOffloadingHandler( + handlers = CpuGpuOffloadingHandlers( attn_backends=attn_backends, gpu_block_size=gpu_block_size, cpu_block_size=cpu_block_size, @@ -112,8 +112,7 @@ def test_transfer( # set transfer direction if gpu_to_cpu: - src_kv_caches = handler.gpu_tensors - dst_kv_caches = handler.cpu_tensors + handler = handlers.gpu_to_cpu_handler src_spec_class = GPULoadStoreSpec dst_spec_class = CPULoadStoreSpec src_blocks = gpu_blocks @@ -122,8 +121,7 @@ def test_transfer( dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block else: - src_kv_caches = handler.cpu_tensors - dst_kv_caches = handler.gpu_tensors + handler = handlers.cpu_to_gpu_handler src_spec_class = CPULoadStoreSpec dst_spec_class = GPULoadStoreSpec src_blocks = cpu_blocks @@ -144,12 +142,12 @@ def test_transfer( dst_spec = dst_spec_class(dst_blocks) # clone src and dst tensors before transfer - orig_src_caches = [x.clone() for x in src_kv_caches] - orig_dst_caches = [x.clone() for x in dst_kv_caches] + orig_src_caches = [x.clone() for x in handler.src_tensors] + orig_dst_caches = [x.clone() for x in handler.dst_tensors] # call transfer function assert handler.transfer_async(1, (src_spec, dst_spec)) - assert set(handler.transfer_events.keys()) == {1} + assert set({x[0] for x in handler._transfers}) == {1} # wait for transfer to complete end_time = time.time() + 10 @@ -161,15 +159,15 @@ def test_transfer( time.sleep(0.1) # verify src tensors did not change - for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches): + for orig_tensor, tensor in 
zip(orig_src_caches, handler.src_tensors): assert torch.equal(orig_tensor, tensor) # verify dst tensors for dst_block in range(dst_size_in_gpu_blocks): src_block_candidate = dst_to_src.get(dst_block) for src_cache, dst_cache, orig_dst_cache, kv_dim in zip( - src_kv_caches, - dst_kv_caches, + handler.src_tensors, + handler.dst_tensors, orig_dst_caches, handler.kv_dim_before_num_blocks, ): diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 513a21dd6bb39..c026ab0e4e785 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -44,8 +44,6 @@ def unsupported_array_schemas(): @pytest.fixture def unsupported_object_schemas(): return [ - {"type": "object", "minProperties": 1}, - {"type": "object", "maxProperties": 5}, {"type": "object", "propertyNames": {"pattern": "^[a-z]+$"}}, {"type": "object", "patternProperties": {"^S": {"type": "string"}}}, ] @@ -79,6 +77,8 @@ def supported_schema(): }, }, }, + "minProperties": 1, + "maxProperties": 100, } diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 010817e79a936..c32bf04c71c1f 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -642,48 +642,130 @@ _OPS_REGISTERED = False class rocm_aiter_ops: + """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM. + + This class centralizes the import and registration of AITER ops, + and provides a unified interface for checking if AITER is enabled. + Operations are only available on supported gfx9 + architectures when aiter is installed. + + The class uses environment variables to control which features are enabled, + allowing fine-grained control over which AITER optimizations are used. + + Environment Variables: + VLLM_ROCM_USE_AITER: Main toggle for all AITER operations. + VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops. + VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations. + VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops. + VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops. + VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen. + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention. + VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply. + VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM. + VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings. + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion. + VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM. + + Note: + The environment variables are assigned when the module is imported, + so you can't change the environment variables after the module is imported. + This is done out of performance consideration. Accessing environment variables + is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067 + so we don't want to do it repeatedly, especially in the hot path (the forward pass). + You can call the refresh_env_variables() function to reload the env variables + after monkey patching the env variables in the unit test. + + Check Functions: + All check functions (is_*_enabled) are decorated with @if_aiter_supported, + which verifies: (1) platform is ROCm, (2) device arch is gfx9, and + (3) aiter library is installed. The check function then also verifies + the corresponding environment variable is enabled. + i.e. 
+ is_enabled() returns True only when current_platform.is_rocm(), + current_platform.is_on_gfx9(), and the aiter library is found + (all three checked by @if_aiter_supported), and cls._AITER_ENABLED + is set (checked by the body of is_enabled() itself). + + Example: + from vllm._aiter_ops import rocm_aiter_ops + + # Check if aiter is enabled before using operations + if rocm_aiter_ops.is_enabled(): + result = rocm_aiter_ops.rms_norm(x, weight, epsilon) + + Operations: + - RMS normalization: rms_norm, rms_norm2d_with_add + - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale + - Fused MoE: fused_moe, asm_moe_tkw1 + - Routing: topk_softmax, biased_grouped_topk, grouped_topk + - MLA decode: mla_decode_fwd + - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant + - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale + """ + + # Check if the env variable is set _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA - _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + + @classmethod + def refresh_env_variables(cls): + """ + Reload the cached class attributes from the environment variables. + They are read once at module import time, so call this method after + monkey patching the env variables (e.g. in a unit test) so that the + class attributes pick up the new values.
+ """ + cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER + cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA + cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA + cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE + cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + @classmethod @if_aiter_supported def is_enabled(cls) -> bool: - """Verifies device specs and availability of aiter main env variable.""" return cls._AITER_ENABLED @classmethod @if_aiter_supported def is_linear_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod @if_aiter_supported def is_linear_fp8_enaled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls.is_linear_enabled() @classmethod @if_aiter_supported def is_rmsnorm_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod @if_aiter_supported def is_fused_moe_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod @@ -694,25 +776,16 @@ class rocm_aiter_ops: @classmethod @if_aiter_supported def is_mla_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod @if_aiter_supported def is_mha_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._MHA_ENABLED - @classmethod - @if_aiter_supported - def is_pa_attn_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" - return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED - @classmethod @if_aiter_supported def is_triton_unified_attn_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 52a58a082683d..2319655008c50 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -498,15 +498,15 @@ def awq_dequantize( def awq_gemm( input: torch.Tensor, qweight: torch.Tensor, - qzeros: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor, split_k_iters: int, ) -> torch.Tensor: if envs.VLLM_USE_TRITON_AWQ: from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton - return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) - return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) + return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters) # gptq @@ -632,8 +632,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): def _awq_gemm_fake( input: torch.Tensor, qweight: torch.Tensor, - qzeros: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor, split_k_iters: torch.SymInt, ) -> torch.Tensor: num_in_feats = input.size(0) diff --git 
a/vllm/attention/layer.py b/vllm/attention/layer.py index c77fc0fad0038..7ef77db8fbb5b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from collections.abc import Callable +import functools from typing import cast import torch @@ -16,7 +16,9 @@ from vllm.attention.backends.abstract import ( MLAAttentionImpl, ) from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend from vllm.attention.selector import get_attn_backend +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.config import CacheConfig, get_current_vllm_config @@ -47,58 +49,9 @@ from vllm.v1.kv_cache_interface import ( SlidingWindowSpec, ) -if current_platform.is_rocm(): - from vllm.platforms.rocm import on_gfx9 -else: - on_gfx9 = lambda *args, **kwargs: False - - -FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) -def maybe_get_vit_flash_attn_backend( - attn_backend: AttentionBackendEnum, - attn_backend_override: AttentionBackendEnum | None = None, -) -> tuple[AttentionBackendEnum, Callable | None]: - if current_platform.is_rocm(): - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - attn_backend = AttentionBackendEnum.ROCM_AITER_FA - elif ( - attn_backend_override is None - and on_gfx9() - and attn_backend == AttentionBackendEnum.FLASH_ATTN - ): - pass - else: - return AttentionBackendEnum.TORCH_SDPA, None - elif current_platform.is_cuda(): - pass - elif current_platform.is_xpu(): - assert attn_backend == AttentionBackendEnum.FLASH_ATTN, ( - "XPU platform only supports FLASH_ATTN as vision attention backend." 
- ) - pass - else: - return AttentionBackendEnum.TORCH_SDPA, None - - if attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - }: - if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - try: - from vllm.attention.utils.fa_utils import flash_attn_varlen_func - except ImportError: - flash_attn_varlen_func = None - else: - flash_attn_varlen_func = None - - return attn_backend, flash_attn_varlen_func - - def _init_kv_cache_quant( layer: nn.Module, quant_config: QuantizationConfig | None, @@ -494,29 +447,15 @@ class MultiHeadAttention(nn.Module): attn_backend_override = None if multimodal_config is not None: attn_backend_override = multimodal_config.mm_encoder_attn_backend - backend = get_vit_attn_backend( + + self.attn_backend = get_vit_attn_backend( head_size=head_size, dtype=dtype, attn_backend_override=attn_backend_override, ) - self.attn_backend = ( - backend - if backend - in { - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.PALLAS, - AttentionBackendEnum.ROCM_AITER_FA, - AttentionBackendEnum.FLASH_ATTN, - } - else AttentionBackendEnum.TORCH_SDPA - ) - - self.attn_backend, self._flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) self.is_flash_attn_backend = self.attn_backend in { @@ -524,6 +463,17 @@ class MultiHeadAttention(nn.Module): AttentionBackendEnum.ROCM_AITER_FA, } + self.fa_version = None + if ( + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + and current_platform.is_cuda() + ): + self.fa_version = get_flash_attn_version() + assert self._flash_attn_varlen_func is not None + self._flash_attn_varlen_func = functools.partial( + self._flash_attn_varlen_func, fa_version=self.fa_version + ) + logger.info_once( f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder." 
) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 0ced0028ded9e..7e3794d408332 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -4,7 +4,7 @@ import functools import torch -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -51,11 +51,19 @@ def create_chunked_local_attention_backend( common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, - ) -> AttentionMetadata: - common_attn_metadata = make_local_attention_virtual_batches( + ): + cm, make_virtual_batches_block_table = make_local_attention_virtual_batches( attention_chunk_size, common_attn_metadata, block_size ) - return super().build(common_prefix_len, common_attn_metadata, fast_build) + metadata = super().build(common_prefix_len, cm, fast_build) + metadata.make_virtual_batches_block_table = make_virtual_batches_block_table + return metadata + + def update_block_table( + self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor + ): + blk_table = metadata.make_virtual_batches_block_table(blk_table) + return super().update_block_table(metadata, blk_table, slot_mapping) attn_backend = subclass_attention_backend( name_prefix=prefix, diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py new file mode 100644 index 0000000000000..c9107ebcab856 --- /dev/null +++ b/vllm/attention/layers/mm_encoder_attention.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch + +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.ops.vit_attn_wrappers import ( + vit_flash_attn_wrapper, + vit_torch_sdpa_wrapper, +) +from vllm.config import MultiModalConfig +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.models.vision import get_vit_attn_backend + +logger = init_logger(__name__) + + +def maybe_get_vit_flash_attn_backend( + attn_backend: AttentionBackendEnum | None, +) -> Callable | None: + # By this point the attn_backend has already been selected; any + # overriding is done in the platform-specific implementation, so the + # backend is not overridden here. This helper only returns the + # flash_attn_varlen_func that matches the given backend, if any. + + if attn_backend == AttentionBackendEnum.FLASH_ATTN: + from vllm.attention.utils.fa_utils import flash_attn_varlen_func + elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + flash_attn_varlen_func = None + + # If attn_backend is TORCH_SDPA (or anything else), + # flash_attn_varlen_func stays None. + return flash_attn_varlen_func + + +@CustomOp.register("mm_encoder_attn") +class MMEncoderAttention(CustomOp): + """Multi-headed attention without any cache, used for multimodal encoder.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float | None = None, + num_kv_heads: int | None = None, + prefix: str = "", + multimodal_config: MultiModalConfig | None = None, + ) -> None: + """ + Args: + num_heads: number of attention heads per partition.
+ head_size: hidden_size per attention head. + scale: scale factor. + num_kv_heads: number of kv heads. + prefix: This has no effect, it is only here to make it easier to + swap between Attention and MultiHeadAttention + multimodal_config: configs for multi-modal. + """ + super().__init__() + + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.layer_name = prefix + + assert self.num_heads % self.num_kv_heads == 0, ( + f"num_heads ({self.num_heads}) is not " + f"divisible by num_kv_heads ({self.num_kv_heads})" + ) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + + # Try to get vision attention backend from multimodal_config. + attn_backend_override = None + if multimodal_config is not None: + attn_backend_override = multimodal_config.mm_encoder_attn_backend + + # Get device-specific vision attention backend. + self.attn_backend = get_vit_attn_backend( + head_size=head_size, + dtype=dtype, + attn_backend_override=attn_backend_override, + ) + + self.is_flash_attn_backend = self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + } + + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, + ) + + logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") + + @classmethod + def enabled(cls) -> bool: + return True + + def reshape_qkv_to_4d( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + bsz: int, + q_len: int, + kv_len: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reshape query, key, value to 4D tensors: + (batch_size, seq_len, num_heads, head_size) + """ + query = query.view(bsz, q_len, self.num_heads, self.head_size) + key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + + if (num_repeat := self.num_queries_per_kv) > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_repeat, dim=2) + value = torch.repeat_interleave(value, num_repeat, dim=2) + + return query, key, value + + def reshape_qkv_to_3d( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + bsz: int, + q_len: int, + kv_len: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reshape query, key, value to 3D tensors: + (batch_size * seq_len, num_heads, head_size) + """ + query = query.view(bsz * q_len, self.num_heads, self.head_size) + key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size) + + if (num_repeat := self.num_queries_per_kv) > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_repeat, dim=1) + value = torch.repeat_interleave(value, num_repeat, dim=1) + + return query, key, value + + def _forward_sdpa( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + ) -> torch.Tensor: + # TODO(Isotr0py): Migrate MultiHeadAttention + assert cu_seqlens is not None + + bsz, q_len = query.size()[:2] + kv_len = key.size(1) + + query, key, value = self.reshape_qkv_to_4d( + query, key, value, bsz, q_len, kv_len + ) + + output = vit_torch_sdpa_wrapper( + q=query, + k=key, + v=value, + cu_seqlens=cu_seqlens, + ) + return output + + def _forward_fa( + self, + 
query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + assert self.flash_attn_varlen_func is not None, ( + "Flash attention function is not set." + ) + # # TODO(Isotr0py): Migrate MultiHeadAttention + assert cu_seqlens is not None and max_seqlen is not None + + bsz = query.shape[0] + + output = vit_flash_attn_wrapper( + q=query, + k=key, + v=value, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=bsz, + is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), + ) + return output + + def forward_native( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + return self._forward_sdpa(query, key, value, cu_seqlens) + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + if self.is_flash_attn_backend: + return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: + return self._forward_sdpa(query, key, value, cu_seqlens) + else: + raise ValueError( + f"Unsupported multi-modal encoder attention backend for CUDA: " + f"{self.attn_backend}." + ) + + def forward_cpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + return self._forward_sdpa(query, key, value, cu_seqlens) + + def forward_xpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + assert self.is_flash_attn_backend, ( + "XPU only supports FLASH_ATTN for vision attention." + ) + return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) + + def forward_tpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + assert self.attn_backend == AttentionBackendEnum.PALLAS, ( + f"MMEncoderAttention on TPU only supports PALLAS backend, " + f"but got {self.attn_backend}." + ) + if cu_seqlens is None: + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + from torch_xla.experimental.custom_kernel import flash_attention + + out = flash_attention(query, key, value, sm_scale=self.scale) + out = out.transpose(1, 2) + return out + logger.warning_once( + "PALLAS backend with cu_seqlens is not supported for ViT yet. 
", + "Falling back to SDPA implementation.", + ) + return self._forward_sdpa(query, key, value, cu_seqlens) diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index 9036c2b801949..46c7d83dfa5c2 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -44,9 +44,7 @@ def flash_attn_maxseqlen_wrapper( dropout_p=0.0, causal=False, ) - context_layer = einops.rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() + context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) return context_layer @@ -59,8 +57,7 @@ def flash_attn_maxseqlen_wrapper_fake( batch_size: int, is_rocm_aiter: bool, ) -> torch.Tensor: - b, s, h, d = q.shape - return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + return torch.empty_like(q) direct_register_custom_op( @@ -106,7 +103,6 @@ def torch_sdpa_wrapper( output_i = einops.rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous() return context_layer @@ -116,8 +112,7 @@ def torch_sdpa_wrapper_fake( v: torch.Tensor, cu_seqlens: torch.Tensor, ) -> torch.Tensor: - b, s, h, d = q.shape - return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + return torch.empty_like(q) direct_register_custom_op( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index bbf95ff009001..e66f698add99d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from functools import cache -from typing import cast, get_args +from typing import NamedTuple, cast, get_args import torch -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.registry import ( MAMBA_TYPE_TO_BACKEND_MAP, MambaAttentionBackendEnum, @@ -18,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) +class AttentionSelectorConfig(NamedTuple): + head_size: int + dtype: torch.dtype + kv_cache_dtype: CacheDType | None + block_size: int | None + use_mla: bool = False + has_sink: bool = False + use_sparse: bool = False + use_mm_prefix: bool = False + attn_type: str = AttentionType.DECODER + + def __repr__(self): + return ( + f"AttentionSelectorConfig(head_size={self.head_size}, " + f"dtype={self.dtype}, " + f"kv_cache_dtype={self.kv_cache_dtype}, " + f"block_size={self.block_size}, " + f"use_mla={self.use_mla}, " + f"has_sink={self.has_sink}, " + f"use_sparse={self.use_sparse}, " + f"use_mm_prefix={self.use_mm_prefix}, " + f"attn_type={self.attn_type})" + ) + + def get_attn_backend( head_size: int, dtype: torch.dtype, @@ -43,8 +68,7 @@ def get_attn_backend( vllm_config = get_current_vllm_config() backend_enum = vllm_config.attention_config.backend - return _cached_get_attn_backend( - backend=backend_enum, + attn_selector_config = AttentionSelectorConfig( head_size=head_size, dtype=dtype, kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), @@ -53,36 +77,25 @@ def get_attn_backend( has_sink=has_sink, use_sparse=use_sparse, use_mm_prefix=use_mm_prefix, - attn_type=attn_type, + attn_type=attn_type or AttentionType.DECODER, + ) + + return _cached_get_attn_backend( + backend=backend_enum, + attn_selector_config=attn_selector_config, ) @cache def _cached_get_attn_backend( 
backend, - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: CacheDType | None, - block_size: int | None, - use_mla: bool = False, - has_sink: bool = False, - use_sparse: bool = False, - use_mm_prefix: bool = False, - attn_type: str | None = None, + attn_selector_config: AttentionSelectorConfig, ) -> type[AttentionBackend]: from vllm.platforms import current_platform attention_cls = current_platform.get_attn_backend_cls( backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - attn_type, + attn_selector_config=attn_selector_config, ) if not attention_cls: raise ValueError( diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 254e4d35e5350..f5d8ea5a975a9 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -235,7 +235,9 @@ async def get_request( def calculate_metrics_for_embeddings( - outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float] + outputs: list[RequestFuncOutput], + dur_s: float, + selected_percentiles: list[float], ) -> EmbedBenchmarkMetrics: """Calculate the metrics for the embedding requests. diff --git a/vllm/benchmarks/startup.py b/vllm/benchmarks/startup.py new file mode 100644 index 0000000000000..086f7bf62f838 --- /dev/null +++ b/vllm/benchmarks/startup.py @@ -0,0 +1,326 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark the cold and warm startup time of vLLM models. + +This script measures total startup time (including model loading, compilation, +and cache operations) for both cold and warm scenarios: +- Cold startup: Fresh start with no caches (temporary cache directories) +- Warm startup: Using cached compilation and model info +""" + +import argparse +import dataclasses +import json +import multiprocessing +import os +import shutil +import tempfile +import time +from contextlib import contextmanager +from typing import Any + +import numpy as np +from tqdm import tqdm + +from vllm.benchmarks.lib.utils import ( + convert_to_pytorch_benchmark_format, + write_to_json, +) +from vllm.engine.arg_utils import EngineArgs + + +@contextmanager +def cold_startup(): + """ + Context manager to measure cold startup time: + 1. Uses a temporary directory for vLLM cache to avoid any pollution + between cold startup iterations. + 2. Uses inductor's fresh_cache to clear torch.compile caches. + """ + from torch._inductor.utils import fresh_cache + + # Use temporary directory for caching to avoid any pollution between cold startups + original_cache_root = os.environ.get("VLLM_CACHE_ROOT") + temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_") + try: + os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir + with fresh_cache(): + yield + finally: + # Clean up temporary cache directory + shutil.rmtree(temp_cache_dir, ignore_errors=True) + if original_cache_root: + os.environ["VLLM_CACHE_ROOT"] = original_cache_root + else: + os.environ.pop("VLLM_CACHE_ROOT", None) + + +def run_startup_in_subprocess(engine_args_dict, result_queue): + """ + Run LLM startup in a subprocess and return timing metrics via a queue. + This ensures complete isolation between iterations. 
+ """ + try: + # Import inside the subprocess to avoid issues with forking + from vllm import LLM + from vllm.engine.arg_utils import EngineArgs + + engine_args = EngineArgs(**engine_args_dict) + + # Measure total startup time + start_time = time.perf_counter() + + llm = LLM(**dataclasses.asdict(engine_args)) + + total_startup_time = time.perf_counter() - start_time + + # Extract compilation time if available + compilation_time = 0.0 + if hasattr(llm.llm_engine, "vllm_config"): + vllm_config = llm.llm_engine.vllm_config + if ( + hasattr(vllm_config, "compilation_config") + and vllm_config.compilation_config is not None + ): + compilation_time = vllm_config.compilation_config.compilation_time + + result_queue.put( + { + "total_startup_time": total_startup_time, + "compilation_time": compilation_time, + } + ) + + except Exception as e: + result_queue.put(None) + result_queue.put(str(e)) + + +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: + base_name = os.path.splitext(args.output_json)[0] + + cold_startup_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_cold_startup_time": results["avg_cold_startup_time"], + }, + extra_info={ + "cold_startup_times": results["cold_startup_times"], + "cold_startup_percentiles": results["cold_startup_percentiles"], + }, + ) + if cold_startup_records: + write_to_json(f"{base_name}.cold_startup.pytorch.json", cold_startup_records) + + cold_compilation_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_cold_compilation_time": results["avg_cold_compilation_time"], + }, + extra_info={ + "cold_compilation_times": results["cold_compilation_times"], + "cold_compilation_percentiles": results["cold_compilation_percentiles"], + }, + ) + if cold_compilation_records: + write_to_json( + f"{base_name}.cold_compilation.pytorch.json", cold_compilation_records + ) + + warm_startup_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_warm_startup_time": results["avg_warm_startup_time"], + }, + extra_info={ + "warm_startup_times": results["warm_startup_times"], + "warm_startup_percentiles": results["warm_startup_percentiles"], + }, + ) + if warm_startup_records: + write_to_json(f"{base_name}.warm_startup.pytorch.json", warm_startup_records) + + warm_compilation_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_warm_compilation_time": results["avg_warm_compilation_time"], + }, + extra_info={ + "warm_compilation_times": results["warm_compilation_times"], + "warm_compilation_percentiles": results["warm_compilation_percentiles"], + }, + ) + if warm_compilation_records: + write_to_json( + f"{base_name}.warm_compilation.pytorch.json", warm_compilation_records + ) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-iters-cold", + type=int, + default=5, + help="Number of cold startup iterations.", + ) + parser.add_argument( + "--num-iters-warmup", + type=int, + default=3, + help="Number of warmup iterations before benchmarking warm startups.", + ) + parser.add_argument( + "--num-iters-warm", + type=int, + default=5, + help="Number of warm startup iterations.", + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the startup time results in JSON format.", + ) + + parser = EngineArgs.add_cli_args(parser) + return parser + + +def main(args: argparse.Namespace): + # Set multiprocessing start method to 'spawn' for clean process isolation + # This 
ensures each subprocess starts fresh without inheriting state + multiprocessing.set_start_method("spawn", force=True) + + engine_args = EngineArgs.from_cli_args(args) + + def create_llm_and_measure_startup(): + """ + Create LLM instance in a subprocess and measure startup time. + Returns timing metrics, using subprocess for complete isolation. + """ + # Convert engine_args to dictionary for pickling + engine_args_dict = dataclasses.asdict(engine_args) + + # Create a queue for inter-process communication + result_queue = multiprocessing.Queue() + process = multiprocessing.Process( + target=run_startup_in_subprocess, + args=( + engine_args_dict, + result_queue, + ), + ) + process.start() + process.join() + + if not result_queue.empty(): + result = result_queue.get() + if result is None: + if not result_queue.empty(): + error_msg = result_queue.get() + raise RuntimeError(f"Subprocess failed: {error_msg}") + else: + raise RuntimeError("Subprocess failed with unknown error") + return result + else: + raise RuntimeError("Subprocess did not return a result") + + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n") + + print("Measuring cold startup time...\n") + cold_startup_times = [] + cold_compilation_times = [] + for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"): + with cold_startup(): + metrics = create_llm_and_measure_startup() + cold_startup_times.append(metrics["total_startup_time"]) + cold_compilation_times.append(metrics["compilation_time"]) + + # Warmup for warm startup + print("\nWarming up for warm startup measurement...\n") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + create_llm_and_measure_startup() + + print("\nMeasuring warm startup time...\n") + warm_startup_times = [] + warm_compilation_times = [] + for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"): + metrics = create_llm_and_measure_startup() + warm_startup_times.append(metrics["total_startup_time"]) + warm_compilation_times.append(metrics["compilation_time"]) + + # Calculate statistics + cold_startup_array = np.array(cold_startup_times) + cold_compilation_array = np.array(cold_compilation_times) + warm_startup_array = np.array(warm_startup_times) + warm_compilation_array = np.array(warm_compilation_times) + + avg_cold_startup = np.mean(cold_startup_array) + avg_cold_compilation = np.mean(cold_compilation_array) + avg_warm_startup = np.mean(warm_startup_array) + avg_warm_compilation = np.mean(warm_compilation_array) + + percentages = [10, 25, 50, 75, 90, 99] + cold_startup_percentiles = np.percentile(cold_startup_array, percentages) + cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages) + warm_startup_percentiles = np.percentile(warm_startup_array, percentages) + warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages) + + print("\n" + "=" * 60) + print("STARTUP TIME BENCHMARK RESULTS") + print("=" * 60) + + # Cold startup statistics + print("\nCOLD STARTUP:") + print(f"Avg total startup time: {avg_cold_startup:.2f} seconds") + print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds") + print("Startup time percentiles:") + for percentage, percentile in zip(percentages, cold_startup_percentiles): + print(f" {percentage}%: {percentile:.2f} seconds") + print("Compilation time percentiles:") + for percentage, percentile in zip(percentages, cold_compilation_percentiles): + print(f" {percentage}%: {percentile:.2f} 
seconds") + + # Warm startup statistics + print("\nWARM STARTUP:") + print(f"Avg total startup time: {avg_warm_startup:.2f} seconds") + print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds") + print("Startup time percentiles:") + for percentage, percentile in zip(percentages, warm_startup_percentiles): + print(f" {percentage}%: {percentile:.2f} seconds") + print("Compilation time percentiles:") + for percentage, percentile in zip(percentages, warm_compilation_percentiles): + print(f" {percentage}%: {percentile:.2f} seconds") + + print("=" * 60) + + # Output JSON results if specified + if args.output_json: + results = { + "avg_cold_startup_time": float(avg_cold_startup), + "avg_cold_compilation_time": float(avg_cold_compilation), + "cold_startup_times": cold_startup_times, + "cold_compilation_times": cold_compilation_times, + "cold_startup_percentiles": dict( + zip(percentages, cold_startup_percentiles.tolist()) + ), + "cold_compilation_percentiles": dict( + zip(percentages, cold_compilation_percentiles.tolist()) + ), + "avg_warm_startup_time": float(avg_warm_startup), + "avg_warm_compilation_time": float(avg_warm_compilation), + "warm_startup_times": warm_startup_times, + "warm_compilation_times": warm_compilation_times, + "warm_startup_percentiles": dict( + zip(percentages, warm_startup_percentiles.tolist()) + ), + "warm_compilation_percentiles": dict( + zip(percentages, warm_compilation_percentiles.tolist()) + ), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index d824e982b7489..37b8952a350b4 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -346,7 +346,10 @@ def get_requests(args, tokenizer): "output_len": args.output_len, } - if args.dataset_path is None or args.dataset_name == "random": + if args.dataset_name == "random" or ( + args.dataset_path is None + and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"} + ): sample_kwargs["range_ratio"] = args.random_range_ratio sample_kwargs["prefix_len"] = args.prefix_len dataset_cls = RandomDataset diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8fcd2b42e13bb..a1eec7d74483f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -463,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # the tag for the part of model being compiled, # e.g. backbone/eagle_head model_tag: str = "backbone" +model_is_encoder: bool = False @contextmanager -def set_model_tag(tag: str): +def set_model_tag(tag: str, is_encoder: bool = False): """Context manager to set the model tag.""" global model_tag + global model_is_encoder assert tag != model_tag, ( f"Model tag {tag} is the same as the current tag {model_tag}." ) old_tag = model_tag + old_is_encoder = model_is_encoder + model_tag = tag + model_is_encoder = is_encoder try: yield finally: model_tag = old_tag + model_is_encoder = old_is_encoder class VllmBackend: @@ -523,6 +529,9 @@ class VllmBackend: # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag + # Mark compilation for encoder. + self.is_encoder = model_is_encoder + # Passes to run on the graph post-grad. 
self.pass_manager = resolve_obj_by_qualname( current_platform.get_pass_manager_cls() diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 31f5e78408460..d1ee995ee8959 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import resolve_obj_by_qualname -from vllm.utils.torch_utils import supports_dynamo +from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo from .monitor import start_monitoring_torch_compile @@ -316,7 +316,13 @@ def _support_torch_compile( def _mark_dynamic_inputs(mod, type, *args, **kwargs): def mark_dynamic(arg, dims): if type == DynamicShapesType.UNBACKED: - torch._dynamo.decorators.mark_unbacked(arg, dims) + if is_torch_equal_or_newer("2.10.0.dev"): + for dim in dims: + torch._dynamo.decorators.mark_unbacked( + arg, dim, hint_override=arg.size()[dim] + ) + else: + torch._dynamo.decorators.mark_unbacked(arg, dims) else: torch._dynamo.mark_dynamic(arg, dims) @@ -350,7 +356,13 @@ def _support_torch_compile( if isinstance(arg, torch.Tensor): # In case dims is specified with negative indexing dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - torch._dynamo.decorators.mark_unbacked(arg, dims) + if is_torch_equal_or_newer("2.10.0.dev"): + for dim in dims: + torch._dynamo.decorators.mark_unbacked( + arg, dim, hint_override=arg.size()[dim] + ) + else: + torch._dynamo.decorators.mark_unbacked(arg, dims) def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation @@ -378,14 +390,6 @@ def _support_torch_compile( serialized backend artifacts), then we need to generate a new AOT compile artifact from scratch. """ - # Validate that AOT compile is not used with unbacked dynamic - # shapes. aot_compile re-allocates backed symbols post dynamo! - if ds_type == DynamicShapesType.UNBACKED: - raise ValueError( - "AOT compilation is not compatible with UNBACKED dynamic shapes. " - "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type " - "when VLLM_USE_AOT_COMPILE is enabled." 
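The version-gated mark_unbacked call appears twice in decorators.py above. A small helper along the following lines could factor it out; this is only a sketch of the pattern already shown in the patch, and mark_unbacked_dims is a hypothetical name.

import torch
from vllm.utils.torch_utils import is_torch_equal_or_newer

def mark_unbacked_dims(arg: torch.Tensor, dims: list[int]) -> None:
    # torch >= 2.10.0.dev exposes a per-dim hint_override; older versions
    # accept the list of dims directly.
    if is_torch_equal_or_newer("2.10.0.dev"):
        for dim in dims:
            torch._dynamo.decorators.mark_unbacked(
                arg, dim, hint_override=arg.size()[dim]
            )
    else:
        torch._dynamo.decorators.mark_unbacked(arg, dims)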
- ) from .caching import compilation_config_hash_factors factors: list[str] = compilation_config_hash_factors(self.vllm_config) @@ -488,6 +492,12 @@ def _support_torch_compile( if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS: fx_config_patches["backed_size_oblivious"] = True + # Prepare inductor config patches + # assume_32bit_indexing is only available in torch 2.10.0.dev+ + inductor_config_patches = {} + if is_torch_equal_or_newer("2.10.0.dev"): + inductor_config_patches["assume_32bit_indexing"] = True + with ( patch.object( InliningInstructionTranslator, "inline_call_", patched_inline_call @@ -496,6 +506,7 @@ def _support_torch_compile( maybe_use_cudagraph_partition_wrapper(self.vllm_config), torch.fx.experimental._config.patch(**fx_config_patches), _torch27_patch_tensor_subclasses(), + torch._inductor.config.patch(**inductor_config_patches), ): if envs.VLLM_USE_AOT_COMPILE: self.aot_compiled_fn = self.aot_compile(*args, **kwargs) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index a7e6a69e64c91..d121106334cb9 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kNvfp4Quant, kStaticTensorScale, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_block_fp8_supported, -) from vllm.platforms import current_platform -from vllm.utils.deep_gemm import ( - is_deep_gemm_e8m0_used, - should_use_deepgemm_for_fp8_linear_for_nk, -) from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm +from .matcher_utils import ( + MatcherFusedAddRMSNorm, + MatcherQuantFP8, + MatcherRMSNorm, +) from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { class RMSNormQuantPattern: - def __init__(self, epsilon: float, key: FusedRMSQuantKey): + def __init__( + self, + epsilon: float, + key: FusedRMSQuantKey, + has_col_major_scales: bool = False, + is_e8m0: bool = False, + ): self.epsilon = epsilon self.quant_dtype = key.quant.dtype config = get_current_vllm_config() self.model_dtype = config.model_config.dtype if config.model_config else None - # groupwise FP8 linear uses col major scales if deepgemm and cutlass - using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk( - self.model_dtype, - config.model_config.hf_config.intermediate_size, - config.model_config.hf_config.hidden_size, - ) - use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported() - use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False - assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -142,7 +136,7 @@ class RMSNormQuantPattern: else MatcherFusedAddRMSNorm(epsilon) ) self.quant_matcher = MatcherQuantFP8( - key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0 + key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 ) @@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape, symmetric=True, + has_col_major_scales: bool = False, + is_e8m0: bool = False, ): scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( @@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), ) self.group_shape = 
group_shape - super().__init__(epsilon, key) + self.has_col_major_scales = has_col_major_scales + self.is_e8m0 = is_e8m0 + super().__init__( + epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 + ) def register(self, pm_pass: PatternMatcherPass): def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): @@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): input = input.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) - scale = self.quant_matcher.make_scale( - input, transposed=self.quant_matcher.use_col_major_scales - ) + scale = self.quant_matcher.make_scale(input, self.has_col_major_scales) at = auto_functionalized( self.FUSED_OP, result=result, @@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): scale_ub=None, residual=residual, group_size=self.group_shape[1], - is_scale_transposed=self.quant_matcher.use_col_major_scales, + is_scale_transposed=self.has_col_major_scales, ) # result, residual, scale @@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape, symmetric=True, + has_col_major_scales: bool = False, + is_e8m0: bool = False, ): scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( @@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), ) self.group_shape = group_shape - super().__init__(epsilon, key) + super().__init__( + epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 + ) def register(self, pm_pass: PatternMatcherPass): def pattern(input: torch.Tensor, weight: torch.Tensor): @@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): result = torch.empty_like(input, dtype=self.quant_dtype) scale = self.quant_matcher.make_scale( - input, transposed=self.quant_matcher.use_col_major_scales + input, transposed=self.quant_matcher.has_col_major_scales ) at = auto_functionalized( self.FUSED_OP, @@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): scale_ub=None, residual=None, group_size=self.group_shape[1], - is_scale_transposed=self.quant_matcher.use_col_major_scales, + is_scale_transposed=self.quant_matcher.has_col_major_scales, ) # result, scale @@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): # Make sure fused add patterns are before simple rms norm, # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: - # Fuse fused_add_rms_norm + fp8 group quant - # Only register group quant patterns on CUDA where the C++ op exists - if current_platform.is_cuda(): - FusedAddRMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) - ).register(self.patterns) - - # Fuse rms_norm + fp8 group quant - RMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) - ).register(self.patterns) - - FusedAddRMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) - ).register(self.patterns) - - # Fuse rms_norm + fp8 group quant - RMSNormGroupQuantPattern( - epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64) - ).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns @@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): # Fuse rms_norm + dynamic per-token fp8 quant RMSNormDynamicQuantPattern(epsilon, 
FP8_DTYPE).register(self.patterns) + # Only register group quant patterns on CUDA where the C++ op exists + if current_platform.is_cuda(): + for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]: + for has_col_major_scales in [True, False]: + for is_e8m0 in [True, False]: + # Fuse fused_add_rms_norm + fp8 group quant + FusedAddRMSNormGroupQuantPattern( + epsilon, + FP8_DTYPE, + group_shape=group_shape, + has_col_major_scales=has_col_major_scales, + is_e8m0=is_e8m0, + ).register(self.patterns) + + # Fuse rms_norm + fp8 group quant + RMSNormGroupQuantPattern( + epsilon, + FP8_DTYPE, + group_shape=group_shape, + has_col_major_scales=has_col_major_scales, + is_e8m0=is_e8m0, + ).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 0c0bece9b3fda..ec9ed34f561b4 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp): self, quant_key: QuantKey, enabled: bool | None = None, - use_col_major_scales: bool = False, - use_e8m0: bool = False, + has_col_major_scales: bool = False, + is_e8m0: bool = False, ): if enabled is None: enabled = QuantFP8.enabled() super().__init__(enabled) self.quant_key = quant_key - self.use_col_major_scales = use_col_major_scales - self.use_e8m0 = use_e8m0 assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] + self.has_col_major_scales = has_col_major_scales + self.is_e8m0 = is_e8m0 + assert quant_key.dtype == current_platform.fp8_dtype(), ( "Only QuantFP8 supported by" ) assert quant_key.scale2 is None - self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) + self.quant_fp8 = QuantFP8( + quant_key.scale.static, + quant_key.scale.group_shape, + column_major_scales=has_col_major_scales, + use_ue8m0=is_e8m0, + ) def forward_custom( self, @@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp): if self.quant_key.scale.group_shape.is_per_group(): assert scale is None - scale = self.make_scale(input, transposed=self.use_col_major_scales) + scale = self.make_scale(input, transposed=self.has_col_major_scales) finfo = torch.finfo(self.quant_key.dtype) fp8_min = finfo.min @@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp): eps=1e-10, fp8_min=fp8_min, fp8_max=fp8_max, - scale_ue8m0=self.use_e8m0, + scale_ue8m0=self.is_e8m0, ) return result, scale diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index a15c693767a51..58d3e2a14b22a 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -53,12 +53,7 @@ class PiecewiseBackend: self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_full_graph = total_piecewise_compiles == 1 - # TODO: we need to generalize encoder compilation to other models - self.is_encoder_compilation = vllm_backend.prefix in [ - "Qwen2_5_VisionPatchEmbed", - "Qwen2_5_VisionPatchMerger", - "Qwen2_5_VisionBlock", - ] + self.is_encoder_compilation = vllm_backend.is_encoder self.compile_ranges = self.compilation_config.get_compile_ranges() if self.is_encoder_compilation: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 3b6cb8a343608..4a98494b3c7b3 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -8,7 +8,7 @@ from dataclasses import field from pathlib import Path from typing import 
TYPE_CHECKING, Any, ClassVar, Literal -from pydantic import Field, TypeAdapter, field_validator +from pydantic import ConfigDict, Field, TypeAdapter, field_validator from pydantic.dataclasses import dataclass import vllm.envs as envs @@ -96,7 +96,7 @@ class CUDAGraphMode(enum.Enum): @config -@dataclass +@dataclass(config=ConfigDict(extra="forbid")) class PassConfig: """Configuration for custom Inductor passes. @@ -251,7 +251,7 @@ class DynamicShapesType(str, enum.Enum): @config -@dataclass +@dataclass(config=ConfigDict(extra="forbid")) class DynamicShapesConfig: """Configuration to control/debug torch compile dynamic shapes.""" @@ -290,7 +290,7 @@ class DynamicShapesConfig: @config -@dataclass +@dataclass(config=ConfigDict(extra="forbid")) class CompilationConfig: """Configuration for compilation. @@ -932,9 +932,13 @@ class CompilationConfig: self.splitting_ops = list(self._attention_ops) added_default_splitting_ops = True elif len(self.splitting_ops) == 0: - logger.warning_once( - "Using piecewise compilation with empty splitting_ops" - ) + if ( + self.cudagraph_mode == CUDAGraphMode.PIECEWISE + or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE + ): + logger.warning_once( + "Using piecewise cudagraph with empty splitting_ops" + ) if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.warning_once( "Piecewise compilation with empty splitting_ops do not" diff --git a/vllm/config/model.py b/vllm/config/model.py index 59e9689567bd2..1de9d15cf8c52 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -8,7 +8,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch -from pydantic import ConfigDict, SkipValidation, field_validator, model_validator +from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers.configuration_utils import ALLOWED_LAYER_TYPES @@ -109,7 +109,7 @@ class ModelConfig: """Convert the model using adapters defined in [vllm.model_executor.models.adapters][]. The most common use case is to adapt a text generation model to be used for pooling tasks.""" - tokenizer: SkipValidation[str] = None # type: ignore + tokenizer: str = Field(default=None) """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode | str = "auto" @@ -164,7 +164,7 @@ class ModelConfig: """The specific revision to use for the tokenizer on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" - max_model_len: SkipValidation[int] = None # type: ignore + max_model_len: int = Field(default=None, gt=0) """Model context length (prompt and output). If unspecified, will be automatically derived from the model config. @@ -175,7 +175,7 @@ class ModelConfig: - 25.6k -> 25,600""" spec_target_max_model_len: int | None = None """Specify the maximum length for spec decoding draft models.""" - quantization: SkipValidation[QuantizationMethods | None] = None + quantization: QuantizationMethods | str | None = None """Method used to quantize the weights. If `None`, we first check the `quantization_config` attribute in the model config file. 
If that is `None`, we assume the model weights are not quantized and use `dtype` to @@ -597,6 +597,14 @@ class ModelConfig: self._verify_cuda_graph() self._verify_bnb_config() + @field_validator("tokenizer", "max_model_len", mode="wrap") + @classmethod + def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: + """Skip validation if the value is `None` when initialisation is delayed.""" + if value is None: + return value + return handler(value) + @field_validator("tokenizer_mode", mode="after") def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str: return tokenizer_mode.lower() @@ -610,10 +618,19 @@ class ModelConfig: @model_validator(mode="after") def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": + """Called after __post_init__""" if not isinstance(self.tokenizer, str): - raise ValueError("tokenizer must be a string after __post_init__.") + raise ValueError( + f"tokenizer must be a string, got " + f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. " + "Please provide a valid tokenizer path or HuggingFace model ID." + ) if not isinstance(self.max_model_len, int): - raise ValueError("max_model_len must be an integer after __post_init__.") + raise ValueError( + f"max_model_len must be a positive integer, " + f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. " + "Example: max_model_len=2048" + ) return self def _get_transformers_backend_cls(self) -> str: @@ -1186,7 +1203,15 @@ class ModelConfig: // block.attention.n_heads_in_group ) - raise RuntimeError("Couldn't determine number of kv heads") + raise RuntimeError( + "Could not determine the number of key-value attention heads " + "from model configuration. " + f"Model: {self.model}, Architecture: {self.architectures}. " + "This usually indicates an unsupported model architecture or " + "missing configuration. " + "Please check if your model is supported at: " + "https://docs.vllm.ai/en/latest/models/supported_models.html" + ) if self.is_attention_free: return 0 @@ -1780,6 +1805,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ ("ForTextEncoding", ("pooling", "embed")), ("EmbeddingModel", ("pooling", "embed")), ("ForSequenceClassification", ("pooling", "classify")), + ("ForTokenClassification", ("pooling", "classify")), ("ForAudioClassification", ("pooling", "classify")), ("ForImageClassification", ("pooling", "classify")), ("ForVideoClassification", ("pooling", "classify")), diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 1f9dd38ac9114..3fe066ec32505 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -156,6 +156,8 @@ class ParallelConfig: enable_dbo: bool = False """Enable dual batch overlap for the model executor.""" + ubatch_size: int = 0 + """Number of ubatch size.""" dbo_decode_token_threshold: int = 32 """The threshold for dual batch overlap for batches only containing decodes. 
@@ -325,6 +327,14 @@ class ParallelConfig: including data parallelism.""" return self.world_size * self.data_parallel_size + @property + def use_ubatching(self) -> bool: + return self.enable_dbo or self.ubatch_size > 1 + + @property + def num_ubatches(self) -> int: + return 2 if self.enable_dbo else self.ubatch_size + def get_next_dp_init_port(self) -> int: """ We might need to initialize process groups in multiple diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 8da3ae538d671..8abbe8ba0103e 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -122,10 +122,12 @@ class SchedulerConfig: the default scheduler. Can be a class directly or the path to a class of form "mod.custom_class".""" - disable_hybrid_kv_cache_manager: bool = False + disable_hybrid_kv_cache_manager: bool | None = None """If set to True, KV cache manager will allocate the same size of KV cache for all attention layers even if there are multiple type of attention layers like full attention and sliding window attention. + If set to None, the default value will be determined based on the environment + and starting configuration. """ async_scheduling: bool = False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b5f8f916de438..0439dc52e7e6f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -870,9 +870,12 @@ class VllmConfig: f"cudagraph_mode={self.compilation_config.cudagraph_mode}" ) - if self.parallel_config.enable_dbo: + if self.parallel_config.use_ubatching: a2a_backend = self.parallel_config.all2all_backend - assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( + assert a2a_backend in [ + "deepep_low_latency", + "deepep_high_throughput", + ], ( "Microbatching currently only supports the deepep_low_latency and " f"deepep_high_throughput all2all backend. {a2a_backend} is not " "supported. To fix use --all2all-backend=deepep_low_latency or " @@ -887,17 +890,48 @@ class VllmConfig: if not self.instance_id: self.instance_id = random_uuid()[:5] - if not self.scheduler_config.disable_hybrid_kv_cache_manager: - # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not now, so we don't log - # warning message here and will log it later. - if not current_platform.support_hybrid_kv_cache(): - # Hybrid KV cache manager is not supported on non-GPU platforms. - self.scheduler_config.disable_hybrid_kv_cache_manager = True + # Hybrid KV cache manager (HMA) runtime rules: + # - Explicit enable (--no-disable-kv-cache-manager): error if runtime + # disables it + # - No preference: auto-disable for unsupported features (e.g. kv connector) + # - Explicit disable (--disable-kv-cache-manager): always respect it + need_disable_hybrid_kv_cache_manager = False + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not current_platform.support_hybrid_kv_cache(): + # Hybrid KV cache manager is not supported on non-GPU platforms. + need_disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. 
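For the new ubatch_size knob and the use_ubatching / num_ubatches properties added to ParallelConfig earlier in this patch, the interaction with enable_dbo can be summarized as below; this is a sketch with made-up values that simply mirrors the property definitions.

def use_ubatching(enable_dbo: bool, ubatch_size: int) -> bool:
    return enable_dbo or ubatch_size > 1

def num_ubatches(enable_dbo: bool, ubatch_size: int) -> int:
    # DBO always means exactly two micro-batches; otherwise the explicit
    # ubatch_size is used as-is.
    return 2 if enable_dbo else ubatch_size

assert use_ubatching(enable_dbo=False, ubatch_size=4)
assert num_ubatches(enable_dbo=False, ubatch_size=4) == 4
assert use_ubatching(enable_dbo=True, ubatch_size=0)
assert num_ubatches(enable_dbo=True, ubatch_size=0) == 2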
+ need_disable_hybrid_kv_cache_manager = True + if ( + self.model_config is not None + and self.model_config.attention_chunk_size is not None + ): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + need_disable_hybrid_kv_cache_manager = True + elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + logger.warning( + "There is a latency regression when using chunked local" + " attention with the hybrid KV cache manager. Disabling" + " it, by default. To enable it, set the environment " + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." + ) + # Hybrid KV cache manager is not yet supported with chunked + # local attention. + need_disable_hybrid_kv_cache_manager = True + + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # Default to disable HMA, but only if the user didn't express a preference. if self.kv_transfer_config is not None: - # NOTE(Kuntai): turn HMA off for connector for now. - # TODO(Kuntai): have a more elegent solution to check and - # turn off HMA for connector that does not support HMA. + # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. + need_disable_hybrid_kv_cache_manager = True logger.warning( "Turning off hybrid kv cache manager because " "`--kv-transfer-config` is set. This will reduce the " @@ -905,33 +939,26 @@ class VllmConfig: "or Mamba attention. If you are a developer of kv connector" ", please consider supporting hybrid kv cache manager for " "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py." + " of `SupportsHMA` defined in kv_connector/v1/base.py and" + " use --no-disable-hybrid-kv-cache-manager to start vLLM." ) - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_events_config is not None: - # Hybrid KV cache manager is not compatible with KV events. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if ( - self.model_config is not None - and self.model_config.attention_chunk_size is not None - ): - if ( - self.speculative_config is not None - and self.speculative_config.use_eagle() - ): - # Hybrid KV cache manager is not yet supported with chunked - # local attention + eagle. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: - logger.warning( - "There is a latency regression when using chunked local" - " attention with the hybrid KV cache manager. Disabling" - " it, by default. To enable it, set the environment " - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." - ) - # Hybrid KV cache manager is not yet supported with chunked - # local attention. - self.scheduler_config.disable_hybrid_kv_cache_manager = True + self.scheduler_config.disable_hybrid_kv_cache_manager = ( + need_disable_hybrid_kv_cache_manager + ) + elif ( + self.scheduler_config.disable_hybrid_kv_cache_manager is False + and need_disable_hybrid_kv_cache_manager + ): + raise ValueError( + "Hybrid KV cache manager was explicitly enabled but is not " + "supported in this configuration. Consider omitting the " + "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide" + " automatically." + ) + + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # Default to enable HMA if not explicitly disabled by user or logic above. 
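Taken together, the branches above give --disable-hybrid-kv-cache-manager tri-state semantics: None lets vLLM decide, True is always respected, and False (explicit enable) errors on hard incompatibilities while bypassing the KV-connector auto-disable. A condensed sketch of that resolution, with illustrative names and simplified inputs:

def resolve_disable_hma(user_value: bool | None,
                        hard_incompatibility: bool,
                        has_kv_connector: bool) -> bool:
    if user_value is None:
        # No preference: auto-disable for unsupported platforms/features
        # and, by default, whenever a KV connector is configured.
        return hard_incompatibility or has_kv_connector
    if user_value is False and hard_incompatibility:
        # User forced HMA on but the configuration cannot support it.
        raise ValueError(
            "Hybrid KV cache manager explicitly enabled but unsupported"
        )
    # An explicit True (disable) or False (enable) is otherwise respected,
    # which is how HMA-capable connectors opt back in.
    return user_value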
+ self.scheduler_config.disable_hybrid_kv_cache_manager = False if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c40dde26b741f..7a4e81cf967de 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if extra_tensors is not None: + raise NotImplementedError( + "extra_tensors is not supported for NaiveAll2AllManager" + ) sp_size = self.tp_group.world_size if is_sequence_parallel else 1 dp_metadata = get_forward_context().dp_metadata assert dp_metadata is not None @@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase): router_logits = self.naive_multicast( router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel ) + return hidden_states, router_logits def combine( @@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): """ Gather hidden_states and router_logits from all dp ranks. """ @@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase): assert dp_metadata is not None sizes = dp_metadata.get_chunk_sizes_across_dp_rank() assert sizes is not None - dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] - hidden_states, router_logits = dist_group.all_gatherv( - [hidden_states, router_logits], + + tensors_to_gather = [hidden_states, router_logits] + if extra_tensors is not None: + tensors_to_gather.extend(extra_tensors) + + gathered_tensors = dist_group.all_gatherv( + tensors_to_gather, dim=0, sizes=sizes, ) - return hidden_states, router_logits + + if extra_tensors is not None: + return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) + return gathered_tensors[0], gathered_tensors[1] def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 3a849da70e4cb..caeff54406b59 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading +from typing import Any from 
weakref import WeakValueDictionary import torch @@ -68,7 +69,11 @@ class All2AllManagerBase: hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ): + extra_tensors: list[torch.Tensor] | None = None, + ) -> Any: + # Subclasses should either: + # - implement handling for extra_tensors, or + # - raise a clear error if extra_tensors is not supported. raise NotImplementedError def set_num_sms(self, num_sms: int): diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index cd9c267beb5b5..9542498c453ec 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase): return output_list - def dispatch( + def dispatch( # type: ignore[override] self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): assert self.all2all_manager is not None - hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.all2all_manager.dispatch( + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, # type: ignore[call-arg] ) - return hidden_states, router_logits def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 376dad8a72ef1..55856d940f001 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -322,9 +322,6 @@ async def transfer_layer( num_local_physical_experts = next(iter(expert_weights[0])).shape[0] assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert num_physical_experts == ep_size * num_local_physical_experts - # A buffer to hold the expert weights in one layer during the exchange. - # NOTE: Currently we assume the same weights across different layers - # have the same shape. is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( num_local_experts=num_local_physical_experts, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py index eb8342eb7129f..28aad71ab48f2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -7,7 +7,6 @@ from prometheus_client import Counter, Gauge, Histogram from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory -from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group from vllm.logger import init_logger PromMetric: TypeAlias = Gauge | Counter | Histogram @@ -53,8 +52,6 @@ class KVConnectorStats: class KVConnectorLogging: def __init__(self, kv_transfer_config: KVTransferConfig | None): - # This should be called on frontend process. - assert not has_kv_transfer_group() # Instantiate the connector's stats class. 
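The dispatch(extra_tensors=...) plumbing added in this patch (the all2all managers and CudaCommunicator above, plus GroupCoordinator in the parallel_state.py hunk further below) keeps the old 2-tuple return when no extras are passed and appends the gathered extras as a trailing list otherwise. A hedged sketch of the calling convention, assuming an all2all backend that supports extra tensors (e.g. allgather_reducescatter); tensor names are placeholders.

import torch
from vllm.distributed.parallel_state import get_dp_group

def dispatch_with_bias(hidden: torch.Tensor,
                       logits: torch.Tensor,
                       bias: torch.Tensor):
    group = get_dp_group()
    # With extra_tensors, the gathered extras come back as a trailing list;
    # backends that do not support extras raise instead of silently dropping.
    hidden, logits, (bias,) = group.dispatch(
        hidden, logits, extra_tensors=[bias]
    )
    return hidden, logits, bias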
if kv_transfer_config and kv_transfer_config.kv_connector: self.connector_cls = KVConnectorFactory.get_connector_class( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 514b8534aaa6b..fb4b8ac391afb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash( return compat_hash +@dataclass +class RemoteMeta: + block_ids: list[int] + host: str + port: int + engine_id: str + request_id: str + + @dataclass class ReqMeta: local_block_ids: list[int] # To be used when logical block size does not match the kernel block size local_physical_block_ids: list[int] - remote_block_ids: list[int] - remote_host: str - remote_port: int - remote_engine_id: str - remote_request_id: str tp_size: int + remote: RemoteMeta | None = None class NixlConnectorMetadata(KVConnectorMetadata): @@ -223,31 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata): self.reqs_in_batch: set[ReqId] = set() self.reqs_not_processed: set[ReqId] = set() - def add_new_req( + def _add_new_req( + self, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ) -> ReqMeta: + return ReqMeta( + local_block_ids=local_block_ids, + local_physical_block_ids=local_block_ids, + # P workers don't need to receive tp_size from proxy here. + tp_size=kv_transfer_params.get("tp_size", 1), + ) + + def add_new_req_to_save( self, request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - load_remote_cache: bool = True, - save_to_host: bool = False, ): - # save and load are mutually exclusive - assert load_remote_cache ^ save_to_host - _req = ReqMeta( - local_block_ids=local_block_ids, - local_physical_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params["remote_block_ids"], - remote_engine_id=kv_transfer_params["remote_engine_id"], - remote_request_id=kv_transfer_params["remote_request_id"], - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], - # P workers don't need to receive tp_size from proxy here. - tp_size=kv_transfer_params.get("tp_size", 1), + self.reqs_to_save[request_id] = self._add_new_req( + local_block_ids, kv_transfer_params ) - if save_to_host: - self.reqs_to_save[request_id] = _req - if load_remote_cache: - self.reqs_to_recv[request_id] = _req + + def add_new_req_to_recv( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + req = self._add_new_req(local_block_ids, kv_transfer_params) + req.remote = RemoteMeta( + block_ids=kv_transfer_params["remote_block_ids"], + engine_id=kv_transfer_params["remote_engine_id"], + request_id=kv_transfer_params["remote_request_id"], + host=kv_transfer_params["remote_host"], + port=kv_transfer_params["remote_port"], + ) + self.reqs_to_recv[request_id] = req class NixlConnector(KVConnectorBase_V1): @@ -666,22 +683,18 @@ class NixlConnectorScheduler: # Loop through scheduled reqs and convert to ReqMeta. 
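For context on the split between add_new_req_to_recv and add_new_req_to_save: the receive path expects the proxy-provided kv_transfer_params to carry the remote placement fields that now live on RemoteMeta. A hedged sketch with made-up values, using only the keys read by the code above:

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnectorMetadata,
)

meta = NixlConnectorMetadata()
meta.add_new_req_to_recv(
    request_id="req-123",                      # illustrative request id
    local_block_ids=[7, 8, 9],
    kv_transfer_params={
        "remote_block_ids": [0, 1, 2],
        "remote_engine_id": "prefill-engine-0",
        "remote_request_id": "req-123",
        "remote_host": "10.0.0.1",
        "remote_port": 5600,
        "tp_size": 1,                          # optional; defaults to 1
    },
)
# meta.reqs_to_recv["req-123"].remote is now a RemoteMeta holding the
# remote block ids, engine id, request id, host, and port shown above.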
for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None - meta.add_new_req( + meta.add_new_req_to_recv( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - load_remote_cache=True, - save_to_host=False, ) for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - meta.add_new_req( + meta.add_new_req_to_save( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - load_remote_cache=False, - save_to_host=True, ) meta.reqs_to_send = self._reqs_need_send @@ -1124,10 +1137,11 @@ class NixlConnectorWorker: # Do NIXL handshake in background and add to _ready_requests when done. fut = self._handshake_futures.get(remote_engine_id) if fut is None: + assert meta.remote is not None fut = self._handshake_initiation_executor.submit( self._nixl_handshake, - meta.remote_host, - meta.remote_port, + meta.remote.host, + meta.remote.port, meta.tp_size, remote_engine_id, ) @@ -1774,6 +1788,7 @@ class NixlConnectorWorker: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" + assert meta.remote is not None if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) if self.enable_permute_local_kv: @@ -1781,7 +1796,7 @@ class NixlConnectorWorker: # post processing for heteroblocksize block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( - meta.remote_engine_id + meta.remote.engine_id ) if ( not self.use_mla @@ -1916,17 +1931,18 @@ class NixlConnectorWorker: meta.local_physical_block_ids = self._logical_to_kernel_block_ids( meta.local_block_ids ) - meta.remote_block_ids = self._logical_to_kernel_block_ids( - meta.remote_block_ids + assert meta.remote is not None + meta.remote.block_ids = self._logical_to_kernel_block_ids( + meta.remote.block_ids ) - remote_engine_id = meta.remote_engine_id + remote_engine_id = meta.remote.engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. 
", req_id, remote_engine_id, len(meta.local_physical_block_ids), - len(meta.remote_block_ids), + len(meta.remote.block_ids), ) # always store metadata for failure recovery self._recving_metadata[req_id] = meta @@ -1965,17 +1981,18 @@ class NixlConnectorWorker: self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + assert meta.remote is not None logger.debug( "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, + meta.remote.engine_id, req_id, ) self._read_blocks( request_id=req_id, - dst_engine_id=meta.remote_engine_id, - remote_request_id=meta.remote_request_id, + dst_engine_id=meta.remote.engine_id, + remote_request_id=meta.remote.request_id, local_block_ids=meta.local_physical_block_ids, - remote_block_ids=meta.remote_block_ids, + remote_block_ids=meta.remote.block_ids, ) def _read_blocks( diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 338cb1f1814b5..f5ada5a009ec3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1007,10 +1007,17 @@ class GroupCoordinator: hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): if self.device_communicator is not None: - return self.device_communicator.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.device_communicator.dispatch( # type: ignore[call-arg] + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, ) else: return hidden_states, router_logits diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2867532756450..ca19e468914c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -408,6 +408,7 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel all2all_backend: str | None = ParallelConfig.all2all_backend enable_dbo: bool = ParallelConfig.enable_dbo + ubatch_size: int = ParallelConfig.ubatch_size dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold disable_nccl_for_dp_synchronization: bool = ( @@ -491,7 +492,7 @@ class EngineArgs: enable_chunked_prefill: bool | None = None disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input - disable_hybrid_kv_cache_manager: bool = ( + disable_hybrid_kv_cache_manager: bool | None = ( SchedulerConfig.disable_hybrid_kv_cache_manager ) @@ -841,6 +842,10 @@ class EngineArgs: "--all2all-backend", **parallel_kwargs["all2all_backend"] ) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) + parallel_group.add_argument( + "--ubatch-size", + **parallel_kwargs["ubatch_size"], + ) parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"], @@ -1557,6 +1562,7 @@ class EngineArgs: enable_expert_parallel=self.enable_expert_parallel, all2all_backend=self.all2all_backend, enable_dbo=self.enable_dbo, + ubatch_size=self.ubatch_size, dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 
aceaa8bd45b81..ab055dfb1fb0e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque from collections.abc import Awaitable, Callable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast import jinja2 import jinja2.ext @@ -24,6 +24,7 @@ from openai.types.chat import ( ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartRefusalParam, ChatCompletionContentPartTextParam, + ChatCompletionFunctionToolParam, ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam, ) @@ -49,11 +50,20 @@ from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import random_uuid +from vllm.utils.collection_utils import is_list_of from vllm.utils.func_utils import supports_kw +from vllm.utils.import_utils import LazyLoader + +if TYPE_CHECKING: + import torch + + from vllm.tokenizers.mistral import MistralTokenizer +else: + torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) @@ -260,6 +270,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): reasoning: str | None """The reasoning content for interleaved thinking.""" + tools: list[ChatCompletionFunctionToolParam] | None + """The tools for developer role.""" + ChatCompletionMessageParam: TypeAlias = ( OpenAIChatCompletionMessageParam @@ -291,6 +304,9 @@ class ConversationMessage(TypedDict, total=False): reasoning_content: str | None """Deprecated: The reasoning content for interleaved thinking.""" + tools: list[ChatCompletionFunctionToolParam] | None + """The tools for developer role.""" + # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] @@ -620,6 +636,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"] _T = TypeVar("_T") +def _extract_embeds(tensors: list[torch.Tensor]): + if len(tensors) == 0: + return tensors + + if len(tensors) == 1: + tensors[0]._is_single_item = True # type: ignore + return tensors[0] # To keep backwards compatibility for single item input + + first_shape = tensors[0].shape + if all(t.shape == first_shape for t in tensors): + return torch.stack(tensors) + + return tensors + + +def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str): + embeds_key = f"{modality}_embeds" + embeds = items_by_modality[embeds_key] + + if len(embeds) == 0: + return embeds + if is_list_of(embeds, torch.Tensor): + return _extract_embeds(embeds) + if is_list_of(embeds, dict): + if not embeds: + return {} + + first_keys = set(embeds[0].keys()) + if any(set(item.keys()) != first_keys for item in embeds[1:]): + raise ValueError( + "All dictionaries in the list of embeddings must have the same keys." 
+ ) + + return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys} + + return embeds + + class BaseMultiModalItemTracker(ABC, Generic[_T]): """ Tracks multi-modal items in a given request and ensures that the number @@ -688,11 +742,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def all_mm_uuids(self) -> MultiModalUUIDDict | None: if not self._items_by_modality: return None - mm_uuids = {} + uuids_by_modality = dict(self._uuids_by_modality) if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") + if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality: + raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_uuids = {} if "image_embeds" in uuids_by_modality: mm_uuids["image"] = uuids_by_modality["image_embeds"] if "image" in uuids_by_modality: @@ -703,6 +760,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios if "video" in uuids_by_modality: mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos + return mm_uuids @abstractmethod @@ -714,29 +772,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None - mm_inputs = {} + items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") if "audio" in items_by_modality and "audio_embeds" in items_by_modality: raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_inputs = {} if "image_embeds" in items_by_modality: - image_embeds_lst = items_by_modality["image_embeds"] - mm_inputs["image"] = ( - image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0] - ) + mm_inputs["image"] = _get_embeds_data(items_by_modality, "image") if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio_embeds" in items_by_modality: - audio_embeds_lst = items_by_modality["audio_embeds"] - mm_inputs["audio"] = ( - audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0] - ) + mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio") if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -747,38 +801,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None - mm_inputs = {} - items_by_modality = {} - for modality, items in self._items_by_modality.items(): - coros = [] - for item in items: - if item is not None: - coros.append(item) - else: - coros.append(asyncio.sleep(0)) - items_by_modality[modality] = await asyncio.gather(*coros) + coros_by_modality = { + modality: [item or asyncio.sleep(0) for item in items] + for modality, items in self._items_by_modality.items() + } + items_by_modality: dict[str, list[object | None]] = { + modality: await asyncio.gather(*coros) + for modality, coros in coros_by_modality.items() + } if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding 
inputs is not allowed") if "audio" in items_by_modality and "audio_embeds" in items_by_modality: raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_inputs = {} if "image_embeds" in items_by_modality: - image_embeds_lst = items_by_modality["image_embeds"] - mm_inputs["image"] = ( - image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0] - ) + mm_inputs["image"] = _get_embeds_data(items_by_modality, "image") if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio_embeds" in items_by_modality: - audio_embeds_lst = items_by_modality["audio_embeds"] - mm_inputs["audio"] = ( - audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0] - ) + mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio") if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -1578,6 +1626,8 @@ def _parse_chat_message_content( if "name" in message and isinstance(message["name"], str): result_msg["name"] = message["name"] + if role == "developer": + result_msg["tools"] = message.get("tools", None) return result @@ -1588,12 +1638,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: # so, for messages that have tool_calls, parse the string (which we get # from openAI format) to dict for message in messages: - if ( - message["role"] == "assistant" - and "tool_calls" in message - and isinstance(message["tool_calls"], list) - ): - for item in message["tool_calls"]: + if message["role"] == "assistant" and "tool_calls" in message: + tool_calls = message.get("tool_calls") + if not isinstance(tool_calls, list): + continue + + if len(tool_calls) == 0: + # Drop empty tool_calls to keep templates on the normal assistant path. 
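Aside on the embedding helpers introduced above: a standalone sketch of the normalization rule in `_extract_embeds` (a single tensor is passed through, equal-shape tensors are stacked into one batch, mixed shapes stay a list). The function name and shapes below are illustrative only, not vLLM's actual implementation.

    import torch

    def extract_embeds(tensors: list[torch.Tensor]):
        # Single item: keep backwards compatibility with single-tensor inputs.
        if len(tensors) <= 1:
            return tensors[0] if tensors else tensors
        # All items share a shape: stack into one batched tensor.
        if all(t.shape == tensors[0].shape for t in tensors):
            return torch.stack(tensors)
        # Heterogeneous shapes: leave as a list of tensors.
        return tensors

    same = extract_embeds([torch.zeros(4, 8), torch.ones(4, 8)])
    assert same.shape == (2, 4, 8)
    mixed = extract_embeds([torch.zeros(3, 8), torch.zeros(5, 8)])
    assert isinstance(mixed, list) and len(mixed) == 2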
+ message.pop("tool_calls", None) + continue + + for item in tool_calls: # if arguments is None or empty string, set to {} if content := item["function"].get("arguments"): if not isinstance(content, (dict, list)): @@ -1792,7 +1847,7 @@ def apply_hf_chat_template( def apply_mistral_chat_template( - tokenizer: MistralTokenizer, + tokenizer: "MistralTokenizer", messages: list[ChatCompletionMessageParam], chat_template: str | None, tools: list[dict[str, Any]] | None, diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py index 9dff68236fe94..dc02ac563406a 100644 --- a/vllm/entrypoints/cli/__init__.py +++ b/vllm/entrypoints/cli/__init__.py @@ -2,12 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand +from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand __all__: list[str] = [ "BenchmarkLatencySubcommand", "BenchmarkServingSubcommand", + "BenchmarkStartupSubcommand", "BenchmarkSweepSubcommand", "BenchmarkThroughputSubcommand", ] diff --git a/vllm/entrypoints/cli/benchmark/startup.py b/vllm/entrypoints/cli/benchmark/startup.py new file mode 100644 index 0000000000000..81eefd7c174dc --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/startup.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm.benchmarks.startup import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase + + +class BenchmarkStartupSubcommand(BenchmarkSubcommandBase): + """The `startup` subcommand for `vllm bench`.""" + + name = "startup" + help = "Benchmark the startup time of vLLM models." 
+ + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index c70eaaa082fe5..b076b883b4d93 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -2,11 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import contextlib +import copy import json import logging from abc import ABC, abstractmethod from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import replace from typing import TYPE_CHECKING, Union from openai.types.responses.response_function_tool_call_output_item import ( @@ -34,13 +36,13 @@ from vllm.entrypoints.openai.protocol import ( ResponseRawMessageAndToken, ResponsesRequest, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser from vllm.entrypoints.responses_utils import construct_tool_dicts from vllm.entrypoints.tool import Tool from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.tokenizers.protocol import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -74,24 +76,24 @@ class TurnMetrics: def __init__( self, - input_tokens=0, - output_tokens=0, - cached_input_tokens=0, - tool_output_tokens=0, - ): + input_tokens: int = 0, + output_tokens: int = 0, + cached_input_tokens: int = 0, + tool_output_tokens: int = 0, + ) -> None: self.input_tokens = input_tokens self.output_tokens = output_tokens self.cached_input_tokens = cached_input_tokens self.tool_output_tokens = tool_output_tokens - def reset(self): + def reset(self) -> None: """Reset counters for a new turn.""" self.input_tokens = 0 self.output_tokens = 0 self.cached_input_tokens = 0 self.tool_output_tokens = 0 - def copy(self): + def copy(self) -> "TurnMetrics": """Create a copy of this turn's token counts.""" return TurnMetrics( self.input_tokens, @@ -164,6 +166,12 @@ class SimpleContext(ConversationContext): def __init__(self): self.last_output = None + + # Accumulated final output for streaming mode + self._accumulated_text: str = "" + self._accumulated_token_ids: list[int] = [] + self._accumulated_logprobs: list = [] + self.num_prompt_tokens = 0 self.num_output_tokens = 0 self.num_cached_tokens = 0 @@ -183,6 +191,13 @@ class SimpleContext(ConversationContext): self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) + # Accumulate text, token_ids, and logprobs for streaming mode + delta_output = output.outputs[0] + self._accumulated_text += delta_output.text + self._accumulated_token_ids.extend(delta_output.token_ids) + if delta_output.logprobs is not None: + self._accumulated_logprobs.extend(delta_output.logprobs) + if len(self.input_messages) == 0: output_prompt = output.prompt or "" output_prompt_token_ids = output.prompt_token_ids or [] @@ -194,11 +209,26 @@ class SimpleContext(ConversationContext): ) self.output_messages.append( ResponseRawMessageAndToken( - message=output.outputs[0].text, - tokens=output.outputs[0].token_ids, + message=delta_output.text, + tokens=delta_output.token_ids, ) ) + @property + def final_output(self) -> RequestOutput | None: + """Return the final output, with 
complete text/token_ids/logprobs.""" + if self.last_output is not None and self.last_output.outputs: + assert isinstance(self.last_output, RequestOutput) + final_output = copy.copy(self.last_output) + # copy inner item to avoid modify last_output + final_output.outputs = [replace(item) for item in self.last_output.outputs] + final_output.outputs[0].text = self._accumulated_text + final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids) + if self._accumulated_logprobs: + final_output.outputs[0].logprobs = self._accumulated_logprobs + return final_output + return self.last_output + def append_tool_output(self, output) -> None: raise NotImplementedError("Should not be called.") @@ -267,12 +297,40 @@ class ParsableContext(ConversationContext): self.chat_template = chat_template self.chat_template_content_format = chat_template_content_format + self.input_messages: list[ResponseRawMessageAndToken] = [] + self.output_messages: list[ResponseRawMessageAndToken] = [] + def append_output(self, output: RequestOutput) -> None: self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) self.parser.process(output.outputs[0]) + # only store if enable_response_messages is True, save memory + if self.request.enable_response_messages: + output_prompt = output.prompt or "" + output_prompt_token_ids = output.prompt_token_ids or [] + if len(self.input_messages) == 0: + self.input_messages.append( + ResponseRawMessageAndToken( + message=output_prompt, + tokens=output_prompt_token_ids, + ) + ) + else: + self.output_messages.append( + ResponseRawMessageAndToken( + message=output_prompt, + tokens=output_prompt_token_ids, + ) + ) + self.output_messages.append( + ResponseRawMessageAndToken( + message=output.outputs[0].text, + tokens=output.outputs[0].token_ids, + ) + ) + def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None: self.parser.response_messages.extend(output) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6440b702f4fa6..2768e267f4837 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -18,6 +18,7 @@ from vllm.beam_search import ( create_sort_beams_key_function, ) from vllm.config import ( + AttentionConfig, CompilationConfig, PoolerConfig, ProfilerConfig, @@ -72,7 +73,8 @@ from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.counter import Counter @@ -174,6 +176,10 @@ class LLM: compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the mode of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. + attention_config: Configuration for attention mechanisms. Can be a + dictionary or an AttentionConfig instance. If a dictionary, it will + be converted to an AttentionConfig. Allows specifying the attention + backend and other attention-related settings. **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. 
Note: @@ -212,6 +218,7 @@ class LLM: | StructuredOutputsConfig | None = None, profiler_config: dict[str, Any] | ProfilerConfig | None = None, + attention_config: dict[str, Any] | AttentionConfig | None = None, kv_cache_memory_bytes: int | None = None, compilation_config: int | dict[str, Any] | CompilationConfig | None = None, logits_processors: list[str | type[LogitsProcessor]] | None = None, @@ -251,51 +258,28 @@ class LLM: if hf_overrides is None: hf_overrides = {} - if compilation_config is not None: - if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig( - mode=CompilationMode(compilation_config) - ) - elif isinstance(compilation_config, dict): - compilation_config_instance = CompilationConfig( - **{ - k: v - for k, v in compilation_config.items() - if is_init_field(CompilationConfig, k) - } - ) - else: - compilation_config_instance = compilation_config - else: - compilation_config_instance = CompilationConfig() + def _make_config(value: Any, cls: type[_R]) -> _R: + """Convert dict/None/instance to a config instance.""" + if value is None: + return cls() + if isinstance(value, dict): + return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)}) # type: ignore[arg-type] + return value - if structured_outputs_config is not None: - if isinstance(structured_outputs_config, dict): - structured_outputs_instance = StructuredOutputsConfig( - **{ - k: v - for k, v in structured_outputs_config.items() - if is_init_field(StructuredOutputsConfig, k) - } - ) - else: - structured_outputs_instance = structured_outputs_config + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + mode=CompilationMode(compilation_config) + ) else: - structured_outputs_instance = StructuredOutputsConfig() + compilation_config_instance = _make_config( + compilation_config, CompilationConfig + ) - if profiler_config is not None: - if isinstance(profiler_config, dict): - profiler_config_instance = ProfilerConfig( - **{ - k: v - for k, v in profiler_config.items() - if is_init_field(ProfilerConfig, k) - } - ) - else: - profiler_config_instance = profiler_config - else: - profiler_config_instance = ProfilerConfig() + structured_outputs_instance = _make_config( + structured_outputs_config, StructuredOutputsConfig + ) + profiler_config_instance = _make_config(profiler_config, ProfilerConfig) + attention_config_instance = _make_config(attention_config, AttentionConfig) # warn about single-process data parallel usage. 
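The three copy-pasted coercion branches in `LLM.__init__` collapse into one `_make_config` helper. A minimal standalone sketch of the same dict-or-instance-to-config pattern, using a dummy dataclass instead of vLLM's config classes (names here are illustrative):

    from dataclasses import dataclass, fields
    from typing import Any, TypeVar

    _C = TypeVar("_C")

    @dataclass
    class DummyConfig:
        backend: str = "auto"
        flash: bool = True

    def make_config(value: Any, cls: type[_C]) -> _C:
        """Accept None, a dict, or an instance and return an instance of cls."""
        if value is None:
            return cls()
        if isinstance(value, dict):
            init_names = {f.name for f in fields(cls) if f.init}
            # Unknown keys are dropped, mirroring the is_init_field filter above.
            return cls(**{k: v for k, v in value.items() if k in init_names})
        return value

    assert make_config(None, DummyConfig) == DummyConfig()
    assert make_config({"backend": "triton", "bogus": 1}, DummyConfig).backend == "triton"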
_dp_size = int(kwargs.get("data_parallel_size", 1)) @@ -340,6 +324,7 @@ class LLM: pooler_config=pooler_config, structured_outputs_config=structured_outputs_instance, profiler_config=profiler_config_instance, + attention_config=attention_config_instance, compilation_config=compilation_config_instance, logits_processors=logits_processors, **kwargs, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7be601d824f34..5d0eacae34dd7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -72,7 +72,6 @@ from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation, ) -from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding @@ -95,6 +94,7 @@ from vllm.entrypoints.utils import ( from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager from vllm.tasks import POOLING_TASKS +from vllm.tool_parsers import ToolParserManager from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.gc_utils import freeze_gc_heap diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index b798b05dcfcbf..a8eef76cd8ae4 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -27,8 +27,8 @@ from vllm.entrypoints.constants import ( H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, ) from vllm.entrypoints.openai.serving_models import LoRAModulePath -from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger +from vllm.tool_parsers import ToolParserManager from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/parser/responses_parser.py b/vllm/entrypoints/openai/parser/responses_parser.py index 00045a7ccfd24..c364d6d80544d 100644 --- a/vllm/entrypoints/openai/parser/responses_parser.py +++ b/vllm/entrypoints/openai/parser/responses_parser.py @@ -3,7 +3,11 @@ import logging from collections.abc import Callable -from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem +from openai.types.responses.response_function_tool_call_output_item import ( + ResponseFunctionToolCallOutputItem, +) +from openai.types.responses.response_output_item import McpCall from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_text import ResponseOutputText from openai.types.responses.response_reasoning_item import ( @@ -11,11 +15,12 @@ from openai.types.responses.response_reasoning_item import ( ResponseReasoningItem, ) +from vllm.entrypoints.constants import MCP_PREFIX from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser from vllm.outputs import CompletionOutput from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.tokenizers.protocol import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ 
-111,6 +116,37 @@ class ResponsesParser: return self + def make_response_output_items_from_parsable_context( + self, + ) -> list[ResponseOutputItem]: + """Given a list of sentences, construct ResponseOutput Items.""" + response_messages = self.response_messages[self.num_init_messages :] + output_messages: list[ResponseOutputItem] = [] + for message in response_messages: + if not isinstance(message, ResponseFunctionToolCallOutputItem): + output_messages.append(message) + else: + if len(output_messages) == 0: + raise ValueError( + "Cannot have a FunctionToolCallOutput before FunctionToolCall." + ) + if isinstance(output_messages[-1], ResponseFunctionToolCall): + mcp_message = McpCall( + id=f"{MCP_PREFIX}{random_uuid()}", + arguments=output_messages[-1].arguments, + name=output_messages[-1].name, + server_label=output_messages[ + -1 + ].name, # TODO: store the server label + type="mcp_call", + status="completed", + output=message.output, + # TODO: support error output + ) + output_messages[-1] = mcp_message + + return output_messages + def get_responses_parser_for_simple_context( *, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index aeff6bded7f00..94dde4564ea0c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel): max_tool_calls: int | None = None metadata: Metadata | None = None model: str | None = None + logit_bias: dict[str, float] | None = None parallel_tool_calls: bool | None = True previous_response_id: str | None = None prompt: ResponsePrompt | None = None @@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel): tools: list[Tool] = Field(default_factory=list) top_logprobs: int | None = 0 top_p: float | None = None + top_k: int | None = None truncation: Literal["auto", "disabled"] | None = "disabled" user: str | None = None @@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel): _DEFAULT_SAMPLING_PARAMS = { "temperature": 1.0, "top_p": 1.0, + "top_k": 0, } def to_sampling_params( @@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel): top_p = default_sampling_params.get( "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] ) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output @@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel): return SamplingParams.from_optional( temperature=temperature, top_p=top_p, + top_k=top_k, max_tokens=max_tokens, logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, @@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel): RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY ), structured_outputs=structured_outputs, + logit_bias=self.logit_bias, ) def is_include_output_logprobs(self) -> bool: @@ -2045,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel): presence_penalty: float | None = 0.0 """The presence penalty to use for sampling.""" + + max_completion_tokens: int | None = None + """The maximum number of tokens to generate.""" # --8<-- [end:transcription-sampling-params] # Default sampling parameters for transcription requests. @@ -2291,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel): # Flattened stream option to simplify form data. 
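On the new `top_k` (and `logit_bias`) plumbing in `ResponsesRequest.to_sampling_params` above: the request value wins, then the server-side default, then the 0 sentinel from `_DEFAULT_SAMPLING_PARAMS`. A small sketch of just that fallback order; this is a simplified illustration, not the actual vLLM code path:

    def resolve_top_k(request_top_k: int | None,
                      default_sampling_params: dict[str, int]) -> int:
        # Request value wins; otherwise the server default; otherwise 0,
        # the "disabled" sentinel used by the diff's defaults.
        if request_top_k is not None:
            return request_top_k
        return default_sampling_params.get("top_k", 0)

    assert resolve_top_k(None, {}) == 0
    assert resolve_top_k(None, {"top_k": 20}) == 20
    assert resolve_top_k(5, {"top_k": 20}) == 5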
stream_include_usage: bool | None = False stream_continuous_usage_stats: bool | None = False + + max_completion_tokens: int | None = None + """The maximum number of tokens to generate.""" # --8<-- [end:translation-extra-params] # Default sampling parameters for translation requests. diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d94fa7dd91937..98fc7810faf96 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -57,11 +57,9 @@ from vllm.entrypoints.openai.serving_engine import ( clamp_prompt_logprobs, ) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import get_max_tokens, should_include_usage -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import CompletionOutput, RequestOutput @@ -73,6 +71,8 @@ from vllm.tokenizers.mistral import ( truncate_tool_call_ids, validate_request_params, ) +from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.utils.collection_utils import as_list from vllm.v1.sample.logits_processor import validate_logits_processors_parameters @@ -234,11 +234,7 @@ class OpenAIServingChat(OpenAIServing): ) if error_check_ret is not None: return error_check_ret - ( - conversation, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( + conversation, engine_prompts = await self._preprocess_chat( request, tokenizer, request.messages, @@ -254,11 +250,7 @@ class OpenAIServingChat(OpenAIServing): ) else: # For GPT-OSS. - ( - conversation, - request_prompts, - engine_prompts, - ) = self._make_request_with_harmony(request) + conversation, engine_prompts = self._make_request_with_harmony(request) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") @@ -278,7 +270,7 @@ class OpenAIServingChat(OpenAIServing): generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): - prompt_text, _, _ = self._get_prompt_components(request_prompts[i]) + prompt_text, _, _ = self._get_prompt_components(engine_prompt) # If we are creating sub requests for multiple prompts, ensure that they # have unique request ids. 
sub_request_id = ( @@ -313,7 +305,7 @@ class OpenAIServingChat(OpenAIServing): self._log_inputs( sub_request_id, - request_prompts[i], + engine_prompt, params=sampling_params, lora_request=lora_request, ) @@ -537,7 +529,7 @@ class OpenAIServingChat(OpenAIServing): request_id: str, model_name: str, conversation: list[ConversationMessage], - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: created_time = int(time.time()) @@ -591,6 +583,11 @@ class OpenAIServingChat(OpenAIServing): try: if self.reasoning_parser: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + reasoning_parser = self.reasoning_parser( tokenizer, chat_template_kwargs=request.chat_template_kwargs, # type: ignore @@ -604,6 +601,11 @@ class OpenAIServingChat(OpenAIServing): # Prepare the tool parser if it's needed try: if tool_choice_auto and self.tool_parser: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + tool_parsers: list[ToolParser | None] = [ self.tool_parser(tokenizer) ] * num_choices @@ -962,21 +964,9 @@ class OpenAIServingChat(OpenAIServing): assert reasoning_end_arr is not None output_token_ids = as_list(output.token_ids) if not reasoning_end_arr[i]: - delta_message = ( - reasoning_parser.extract_reasoning_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output_token_ids, - ) - ) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. - # Remove the text and token ids related - # to 'reasoning'. if ( res.prompt_token_ids and reasoning_parser.is_reasoning_end( @@ -985,30 +975,38 @@ class OpenAIServingChat(OpenAIServing): ): reasoning_end_arr[i] = True current_token_ids = output_token_ids - if delta_message and delta_message.content: - current_text = delta_message.content - delta_message.content = None - else: - current_text = "" - # When encountering think end id in delta_token_ids, - # set reasoning status to end. - # Remove the text and token ids related - # to 'reasoning'. - if reasoning_parser.is_reasoning_end(output_token_ids): - reasoning_end_arr[i] = True - current_token_ids = ( - reasoning_parser.extract_content_ids( - output_token_ids + # Don't update current_text, keep it as is from delta + else: + delta_message = ( + reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, ) ) - if delta_message and delta_message.content: - current_text = delta_message.content - delta_message.content = None - else: - current_text = "" + + # When encountering think end id in delta_token_ids, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning'. 
+ if reasoning_parser.is_reasoning_end(output_token_ids): + reasoning_end_arr[i] = True + current_token_ids = ( + reasoning_parser.extract_content_ids( + output_token_ids + ) + ) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" # handle tool calls only after reasoning is done, - else: + if reasoning_end_arr[i]: delta_token_ids = output_token_ids # First time to tool call, # add the remaining text and token ids @@ -1317,7 +1315,7 @@ class OpenAIServingChat(OpenAIServing): request_id: str, model_name: str, conversation: list[ConversationMessage], - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, request_metadata: RequestResponseMetadata, ) -> ErrorResponse | ChatCompletionResponse: created_time = int(time.time()) @@ -1367,6 +1365,11 @@ class OpenAIServingChat(OpenAIServing): reasoning = None if self.tool_parser is not None: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + tool_parser = self.tool_parser(tokenizer) # NOTE: We use token_ids for openai tool parser tool_call_info = tool_parser.extract_tool_calls( @@ -1409,6 +1412,11 @@ class OpenAIServingChat(OpenAIServing): if self.reasoning_parser: try: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + reasoning_parser = self.reasoning_parser( tokenizer, chat_template_kwargs=request.chat_template_kwargs, # type: ignore @@ -1648,7 +1656,7 @@ class OpenAIServingChat(OpenAIServing): self, logprobs: dict[int, Logprob], top_logprobs: int | None, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, should_return_as_token_id: bool, ) -> list[ChatCompletionLogProb]: return [ @@ -1672,7 +1680,7 @@ class OpenAIServingChat(OpenAIServing): self, token_ids: GenericSequence[int], top_logprobs: GenericSequence[dict[int, Logprob] | None], - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, num_output_top_logprobs: int | None = None, return_as_token_id: bool | None = None, ) -> ChatCompletionLogProbs: @@ -1690,6 +1698,11 @@ class OpenAIServingChat(OpenAIServing): if should_return_as_token_id: token = f"token_id:{token_id}" else: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + token = tokenizer.decode(token_id) logprobs_content.append( @@ -1800,10 +1813,10 @@ class OpenAIServingChat(OpenAIServing): # Render prompt token ids. 
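The restructured streaming branch above only runs the reasoning parser until the end-of-think marker appears, then routes later deltas to the content/tool-call path. A toy state machine illustrating that split; the token id and helper name are made up for the example and do not correspond to any real tokenizer or vLLM parser:

    THINK_END = 151668  # assumed end-of-reasoning token id, for illustration only

    def split_stream(deltas: list[list[int]]) -> tuple[list[int], list[int]]:
        reasoning: list[int] = []
        content: list[int] = []
        reasoning_done = False
        for delta_token_ids in deltas:
            if not reasoning_done and THINK_END in delta_token_ids:
                # Switch state mid-delta: tokens before the marker are reasoning,
                # tokens after it are regular content.
                idx = delta_token_ids.index(THINK_END)
                reasoning.extend(delta_token_ids[:idx])
                content.extend(delta_token_ids[idx + 1:])
                reasoning_done = True
            elif not reasoning_done:
                reasoning.extend(delta_token_ids)
            else:
                content.extend(delta_token_ids)
        return reasoning, content

    assert split_stream([[1, 2], [3, THINK_END, 4], [5]]) == ([1, 2, 3], [4, 5])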
prompt_token_ids = render_for_completion(messages) - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) # Add cache_salt if provided in the request if request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt - return messages, [prompt_token_ids], [engine_prompt] + return messages, [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a799432baeb40..5f7cfaa53ec18 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,29 +5,60 @@ import json import sys import time import traceback -from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence +from collections.abc import AsyncGenerator, Callable, Iterable, Mapping from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from http import HTTPStatus from typing import Any, ClassVar, Generic, TypeAlias, TypeVar import numpy as np -import torch from fastapi import Request +from openai.types.responses import ( + ToolChoiceFunction, +) from pydantic import ConfigDict, TypeAdapter from starlette.datastructures import Headers -from typing_extensions import TypeIs +import vllm.envs as envs +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages_futures, + resolve_chat_template_content_format, +) from vllm.entrypoints.context import ( + ConversationContext, HarmonyContext, ParsableContext, StreamingHarmonyContext, ) +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + ErrorInfo, + ErrorResponse, FunctionCall, + FunctionDefinition, ResponseInputOutputItem, ResponsesRequest, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest, ) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.pooling.classify.protocol import ( ClassificationChatRequest, ClassificationCompletionRequest, @@ -49,58 +80,13 @@ from vllm.entrypoints.pooling.score.protocol import ( ScoreRequest, ScoreResponse, ) -from vllm.transformers_utils.tokenizer import AnyTokenizer - -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - -from openai.types.responses import ( - ToolChoiceFunction, -) - -import vllm.envs as envs -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ( - ChatCompletionMessageParam, - ChatTemplateContentFormatOption, - ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages_futures, - resolve_chat_template_content_format, -) -from vllm.entrypoints.context import ConversationContext -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import ( - ChatCompletionNamedToolChoiceParam, - ChatCompletionRequest, - ChatCompletionResponse, - 
CompletionRequest, - CompletionResponse, - DetokenizeRequest, - ErrorInfo, - ErrorResponse, - FunctionDefinition, - TokenizeChatRequest, - TokenizeCompletionRequest, - TokenizeResponse, - TranscriptionRequest, - TranscriptionResponse, - TranslationRequest, -) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig from vllm.entrypoints.responses_utils import ( construct_input_messages, ) from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.utils import _validate_truncation_size -from vllm.inputs.data import PromptType -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.parse import ( PromptComponents, get_prompt_components, @@ -109,15 +95,15 @@ from vllm.inputs.parse import ( from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest -from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin - MultiModalDataDict, - MultiModalUUIDDict, -) +from vllm.multimodal import MultiModalDataDict from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers import ToolParser, ToolParserManager from vllm.tracing import ( contains_trace_headers, extract_trace_headers, @@ -183,34 +169,6 @@ AnyResponse: TypeAlias = ( ) -class TextTokensPrompt(TypedDict): - prompt: str - prompt_token_ids: list[int] - - -class EmbedsPrompt(TypedDict): - prompt_embeds: torch.Tensor - - -RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt - - -def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: - return ( - isinstance(prompt, dict) - and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt - ) - - -def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: - return ( - isinstance(prompt, dict) - and "prompt_token_ids" not in prompt - and "prompt_embeds" in prompt - ) - - RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -221,8 +179,7 @@ class RequestProcessingMixin: handling prompt preparation and engine input. 
""" - request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list) - engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list) + engine_prompts: list[TokensPrompt] | None = field(default_factory=list) @dataclass(kw_only=True) @@ -423,7 +380,7 @@ class OpenAIServing: prompts_batch, lora_req_batch = zip( *[ ( - EngineTokensPrompt( + TokensPrompt( prompt_token_ids=beam.tokens, multi_modal_data=beam.multi_modal_data, mm_processor_kwargs=beam.mm_processor_kwargs, @@ -945,7 +902,7 @@ class OpenAIServing: prompt: str, tokenizer: TokenizerLike, add_special_tokens: bool, - ) -> TextTokensPrompt: + ) -> TokensPrompt: async_tokenizer = self._get_async_tokenizer(tokenizer) if ( @@ -986,7 +943,7 @@ class OpenAIServing: request: AnyRequest, prompt_ids: list[int], tokenizer: TokenizerLike | None, - ) -> TextTokensPrompt: + ) -> TokensPrompt: truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) if truncate_prompt_tokens is None: @@ -1009,7 +966,7 @@ class OpenAIServing: request: AnyRequest, input_ids: list[int], input_text: str, - ) -> TextTokensPrompt: + ) -> TokensPrompt: token_num = len(input_ids) # Note: EmbeddingRequest, ClassificationRequest, @@ -1040,7 +997,7 @@ class OpenAIServing: f"{token_num} tokens in the input for {operation}. " f"Please reduce the length of the input." ) - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation @@ -1048,7 +1005,7 @@ class OpenAIServing: request, (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), ): - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # chat completion endpoint supports max_completion_tokens if isinstance(request, ChatCompletionRequest): @@ -1076,7 +1033,7 @@ class OpenAIServing: f" - {token_num})." ) - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) async def _tokenize_prompt_input_async( self, @@ -1084,7 +1041,7 @@ class OpenAIServing: tokenizer: TokenizerLike, prompt_input: str | list[int], add_special_tokens: bool = True, - ) -> TextTokensPrompt: + ) -> TokensPrompt: """ A simpler implementation that tokenizes a single prompt input. """ @@ -1103,7 +1060,7 @@ class OpenAIServing: tokenizer: TokenizerLike, prompt_inputs: Iterable[str | list[int]], add_special_tokens: bool = True, - ) -> AsyncGenerator[TextTokensPrompt, None]: + ) -> AsyncGenerator[TokensPrompt, None]: """ A simpler implementation that tokenizes multiple prompt inputs. 
""" @@ -1156,11 +1113,7 @@ class OpenAIServing: chat_template_kwargs: dict[str, Any] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, add_special_tokens: bool = False, - ) -> tuple[ - list[ConversationMessage], - Sequence[RequestPrompt], - list[EngineTokensPrompt], - ]: + ) -> tuple[list[ConversationMessage], list[TokensPrompt]]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -1233,9 +1186,7 @@ class OpenAIServing: "Prompt has to be a string", "when the tokenizer is not initialised", ) - prompt_inputs = TextTokensPrompt( - prompt=request_prompt, prompt_token_ids=[1] - ) + prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) elif isinstance(request_prompt, str): prompt_inputs = await self._tokenize_prompt_input_async( request, @@ -1248,14 +1199,15 @@ class OpenAIServing: assert is_list_of(request_prompt, int), ( "Prompt has to be either a string or a list of token ids" ) - prompt_inputs = TextTokensPrompt( + prompt_inputs = TokensPrompt( prompt=tokenizer.decode(request_prompt), prompt_token_ids=request_prompt, ) - engine_prompt = EngineTokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"] - ) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if "prompt" in prompt_inputs: + engine_prompt["prompt"] = prompt_inputs["prompt"] + if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data @@ -1268,7 +1220,7 @@ class OpenAIServing: if hasattr(request, "cache_salt") and request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt - return conversation, [request_prompt], [engine_prompt] + return conversation, [engine_prompt] async def _process_inputs( self, @@ -1300,7 +1252,7 @@ class OpenAIServing: async def _render_next_turn( self, request: ResponsesRequest, - tokenizer: AnyTokenizer, + tokenizer: TokenizerLike | None, messages: list[ResponseInputOutputItem], tool_dicts: list[dict[str, Any]] | None, tool_parser, @@ -1311,7 +1263,7 @@ class OpenAIServing: request_input=messages, ) - _, request_prompts, engine_prompts = await self._preprocess_chat( + _, engine_prompts = await self._preprocess_chat( request, tokenizer, new_messages, @@ -1320,20 +1272,20 @@ class OpenAIServing: chat_template=chat_template, chat_template_content_format=chat_template_content_format, ) - return request_prompts, engine_prompts + return engine_prompts async def _generate_with_builtin_tools( self, request_id: str, - request_prompt: RequestPrompt, - engine_prompt: EngineTokensPrompt, + engine_prompt: TokensPrompt, sampling_params: SamplingParams, context: ConversationContext, lora_request: LoRARequest | None = None, priority: int = 0, **kwargs, ): - prompt_text, _, _ = self._get_prompt_components(request_prompt) + prompt_text, _, _ = self._get_prompt_components(engine_prompt) + orig_priority = priority sub_request = 0 while True: @@ -1341,7 +1293,7 @@ class OpenAIServing: sub_request_id = f"{request_id}_{sub_request}" self._log_inputs( sub_request_id, - request_prompt, + engine_prompt, params=sampling_params, lora_request=lora_request, ) @@ -1386,10 +1338,9 @@ class OpenAIServing: # Render the next prompt token ids. 
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): prompt_token_ids = context.render_for_completion() - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) - request_prompt = prompt_token_ids + engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) elif isinstance(context, ParsableContext): - request_prompts, engine_prompts = await self._render_next_turn( + engine_prompts = await self._render_next_turn( context.request, context.tokenizer, context.parser.response_messages, @@ -1399,8 +1350,7 @@ class OpenAIServing: context.chat_template_content_format, ) engine_prompt = engine_prompts[0] - request_prompt = request_prompts[0] - prompt_text, _, _ = self._get_prompt_components(request_prompt) + prompt_text, _, _ = self._get_prompt_components(engine_prompt) # Update the sampling params. sampling_params.max_tokens = self.max_model_len - len( @@ -1410,19 +1360,13 @@ class OpenAIServing: priority = orig_priority - 1 sub_request += 1 - def _get_prompt_components( - self, - prompt: RequestPrompt | PromptType, - ) -> PromptComponents: - if isinstance(prompt, list): - return PromptComponents(token_ids=prompt) - - return get_prompt_components(prompt) # type: ignore[arg-type] + def _get_prompt_components(self, prompt: PromptType) -> PromptComponents: + return get_prompt_components(prompt) def _log_inputs( self, request_id: str, - inputs: RequestPrompt | PromptType, + inputs: PromptType, params: SamplingParams | PoolingParams | BeamSearchParams | None, lora_request: LoRARequest | None, ) -> None: @@ -1484,7 +1428,7 @@ class OpenAIServing: @staticmethod def _parse_tool_calls_from_content( request: ResponsesRequest | ChatCompletionRequest, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, enable_auto_tools: bool, tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, content: str | None = None, @@ -1524,6 +1468,11 @@ class OpenAIServing: and enable_auto_tools and (request.tool_choice == "auto" or request.tool_choice is None) ): + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + # Automatic Tool Call Parsing try: tool_parser = tool_parser_cls(tokenizer) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 60d14337dcaaf..1f9b5704624ab 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -104,10 +104,9 @@ from vllm.entrypoints.responses_utils import ( construct_input_messages, construct_tool_dicts, extract_tool_types, - make_response_output_items_from_parsable_context, ) from vllm.entrypoints.tool_server import ToolServer -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs @@ -258,7 +257,7 @@ class OpenAIServingResponses(OpenAIServing): self.tool_server = tool_server def _validate_generator_input( - self, engine_prompt: EngineTokensPrompt + self, engine_prompt: TokensPrompt ) -> ErrorResponse | None: """Add validations to the input to the generator here.""" if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): @@ -353,11 +352,11 @@ class OpenAIServingResponses(OpenAIServing): tokenizer = await self.engine_client.get_tokenizer() if self.use_harmony: - messages, request_prompts, engine_prompts = ( - self._make_request_with_harmony(request, prev_response) + messages, engine_prompts = 
self._make_request_with_harmony( + request, prev_response ) else: - messages, request_prompts, engine_prompts = await self._make_request( + messages, engine_prompts = await self._make_request( request, prev_response, tokenizer ) @@ -393,7 +392,7 @@ class OpenAIServingResponses(OpenAIServing): assert len(builtin_tool_list) == 0 available_tools = [] try: - for i, engine_prompt in enumerate(engine_prompts): + for engine_prompt in engine_prompts: maybe_error = self._validate_generator_input(engine_prompt) if maybe_error is not None: return maybe_error @@ -420,7 +419,7 @@ class OpenAIServingResponses(OpenAIServing): context = HarmonyContext(messages, available_tools) else: if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: - # This is an feature in development for parsing + # This is a feature in development for parsing # tokens during generation instead of at the end context = ParsableContext( response_messages=messages, @@ -449,7 +448,6 @@ class OpenAIServingResponses(OpenAIServing): ) generator = self._generate_with_builtin_tools( request_id=request.request_id, - request_prompt=request_prompts[i], engine_prompt=engine_prompt, sampling_params=sampling_params, context=context, @@ -564,7 +562,7 @@ class OpenAIServingResponses(OpenAIServing): prev_msg=self.msg_store.get(prev_response.id) if prev_response else None, prev_response_output=prev_response.output if prev_response else None, ) - _, request_prompts, engine_prompts = await self._preprocess_chat( + _, engine_prompts = await self._preprocess_chat( request, tokenizer, messages, @@ -573,7 +571,7 @@ class OpenAIServingResponses(OpenAIServing): chat_template=self.chat_template, chat_template_content_format=self.chat_template_content_format, ) - return messages, request_prompts, engine_prompts + return messages, engine_prompts def _make_request_with_harmony( self, @@ -586,13 +584,13 @@ class OpenAIServingResponses(OpenAIServing): ) messages = self._construct_input_messages_with_harmony(request, prev_response) prompt_token_ids = render_for_completion(messages) - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) # Add cache_salt if provided in the request if request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt - return messages, [prompt_token_ids], [engine_prompt] + return messages, [engine_prompt] async def _initialize_tool_sessions( self, @@ -659,24 +657,19 @@ class OpenAIServingResponses(OpenAIServing): else: status = "incomplete" elif isinstance(context, ParsableContext): - response_messages = context.parser.response_messages[ - context.parser.num_init_messages : - ] - output = make_response_output_items_from_parsable_context(response_messages) + output = context.parser.make_response_output_items_from_parsable_context() - # TODO: context for non-gptoss models doesn't use messages - # so we can't get them out yet if request.enable_response_messages: - raise NotImplementedError( - "enable_response_messages is currently only supported for gpt-oss" - ) + input_messages = context.input_messages + output_messages = context.output_messages # TODO: Calculate usage. 
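`context.final_output`, consumed just below, exists because streamed `RequestOutput`s only carry deltas. A simplified sketch of the accumulation idea using placeholder classes rather than vLLM's actual types:

    from dataclasses import dataclass, field

    @dataclass
    class DeltaOutput:
        text: str
        token_ids: list[int]

    @dataclass
    class Accumulator:
        text: str = ""
        token_ids: list[int] = field(default_factory=list)

        def append_output(self, delta: DeltaOutput) -> None:
            # Each streamed chunk only contains the new text and token ids,
            # so they are stitched back together for the final response.
            self.text += delta.text
            self.token_ids.extend(delta.token_ids)

    acc = Accumulator()
    for delta in [DeltaOutput("Hel", [1]), DeltaOutput("lo", [2, 3])]:
        acc.append_output(delta)
    assert (acc.text, acc.token_ids) == ("Hello", [1, 2, 3])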
# assert final_res.prompt_token_ids is not None num_tool_output_tokens = 0 else: assert isinstance(context, SimpleContext) - final_res = context.last_output + # Use final_output which has accumulated text/token_ids/logprobs + final_res = context.final_output assert final_res is not None assert len(final_res.outputs) == 1 final_output = final_res.outputs[0] diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index cea9924ebbaca..df9c06adb105a 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing): try: # Unlike most decoder-only models, whisper generation length is not # constrained by the size of the input audio, which is mapped to a - # fixed-size log-mel-spectogram. - default_max_tokens = self.model_config.max_model_len + # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be + # generated by respecting the extra completion tokens arg. + if request.max_completion_tokens is None: + default_max_tokens = self.model_config.max_model_len + else: + default_max_tokens = min( + self.model_config.max_model_len, request.max_completion_tokens + ) sampling_params = request.to_sampling_params( default_max_tokens, self.default_sampling_params ) diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 7be1263e802dc..ad1b682a9ef65 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,150 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, - ToolParserManager, -) - -__all__ = ["ToolParser", "ToolParserManager"] +import warnings -""" -Register a lazy module mapping. +def __getattr__(name: str): + if name == "ToolParser": + from vllm.tool_parsers import ToolParser -Example: - ToolParserManager.register_lazy_module( - name="kimi_k2", - module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser", - class_name="KimiK2ToolParser", - ) -""" + warnings.warn( + "`vllm.entrypoints.openai.tool_parsers.ToolParser` has been moved to " + "`vllm.tool_parsers.ToolParser`. 
" + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + return ToolParser + if name == "ToolParserManager": + from vllm.tool_parsers import ToolParserManager -_TOOL_PARSERS_TO_REGISTER = { - "deepseek_v3": ( # name - "deepseekv3_tool_parser", # filename - "DeepSeekV3ToolParser", # class_name - ), - "deepseek_v31": ( - "deepseekv31_tool_parser", - "DeepSeekV31ToolParser", - ), - "deepseek_v32": ( - "deepseekv32_tool_parser", - "DeepSeekV32ToolParser", - ), - "ernie45": ( - "ernie45_tool_parser", - "Ernie45ToolParser", - ), - "glm45": ( - "glm4_moe_tool_parser", - "Glm4MoeModelToolParser", - ), - "granite-20b-fc": ( - "granite_20b_fc_tool_parser", - "Granite20bFCToolParser", - ), - "granite": ( - "granite_tool_parser", - "GraniteToolParser", - ), - "hermes": ( - "hermes_tool_parser", - "Hermes2ProToolParser", - ), - "hunyuan_a13b": ( - "hunyuan_a13b_tool_parser", - "HunyuanA13BToolParser", - ), - "internlm": ( - "internlm2_tool_parser", - "Internlm2ToolParser", - ), - "jamba": ( - "jamba_tool_parser", - "JambaToolParser", - ), - "kimi_k2": ( - "kimi_k2_tool_parser", - "KimiK2ToolParser", - ), - "llama3_json": ( - "llama_tool_parser", - "Llama3JsonToolParser", - ), - "llama4_json": ( - "llama_tool_parser", - "Llama3JsonToolParser", - ), - "llama4_pythonic": ( - "llama4_pythonic_tool_parser", - "Llama4PythonicToolParser", - ), - "longcat": ( - "longcat_tool_parser", - "LongcatFlashToolParser", - ), - "minimax_m2": ( - "minimax_m2_tool_parser", - "MinimaxM2ToolParser", - ), - "minimax": ( - "minimax_tool_parser", - "MinimaxToolParser", - ), - "mistral": ( - "mistral_tool_parser", - "MistralToolParser", - ), - "olmo3": ( - "olmo3_tool_parser", - "Olmo3PythonicToolParser", - ), - "openai": ( - "openai_tool_parser", - "OpenAIToolParser", - ), - "phi4_mini_json": ( - "phi4mini_tool_parser", - "Phi4MiniJsonToolParser", - ), - "pythonic": ( - "pythonic_tool_parser", - "PythonicToolParser", - ), - "qwen3_coder": ( - "qwen3coder_tool_parser", - "Qwen3CoderToolParser", - ), - "qwen3_xml": ( - "qwen3xml_tool_parser", - "Qwen3XMLToolParser", - ), - "seed_oss": ( - "seed_oss_tool_parser", - "SeedOssToolParser", - ), - "step3": ( - "step3_tool_parser", - "Step3ToolParser", - ), - "xlam": ( - "xlam_tool_parser", - "xLAMToolParser", - ), - "gigachat3": ( - "gigachat3_tool_parser", - "GigaChat3ToolParser", - ), -} + warnings.warn( + "`vllm.entrypoints.openai.tool_parsers.ToolParserManager` " + "has been moved to `vllm.tool_parsers.ToolParserManager`. 
" + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + return ToolParserManager -def register_lazy_tool_parsers(): - for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items(): - module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}" - ToolParserManager.register_lazy_module(name, module_path, class_name) - - -register_lazy_tool_parsers() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index d6d3825daf7bb..e166405a6f05a 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -72,11 +72,7 @@ class ClassificationMixin(OpenAIServing): if ret: return ret - ( - _, - _, - engine_prompts, - ) = await self._preprocess_chat( + _, engine_prompts = await self._preprocess_chat( cast(ChatCompletionRequest, chat_request), ctx.tokenizer, messages, diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index aafc354897105..f5a21208ed802 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -20,7 +20,6 @@ from vllm.entrypoints.openai.serving_engine import ( EmbeddingServeContext, OpenAIServing, ServeContext, - TextTokensPrompt, ) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.pooling.embed.protocol import ( @@ -32,7 +31,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingResponseData, ) from vllm.entrypoints.renderer import RenderConfig -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.outputs import ( EmbeddingRequestOutput, @@ -83,11 +82,7 @@ class EmbeddingMixin(OpenAIServing): renderer = self._get_renderer(tokenizer) if isinstance(ctx.request, EmbeddingChatRequest): - ( - _, - _, - ctx.engine_prompts, - ) = await self._preprocess_chat( + _, ctx.engine_prompts = await self._preprocess_chat( ctx.request, tokenizer, ctx.request.messages, @@ -209,14 +204,13 @@ class EmbeddingMixin(OpenAIServing): async def _process_chunked_request( self, ctx: EmbeddingServeContext, - original_prompt: TextTokensPrompt, + token_ids: list[int], pooling_params, trace_headers, prompt_idx: int, ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: """Process a single prompt using chunked processing.""" generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - token_ids = original_prompt["prompt_token_ids"] # Split into chunks using max_position_embeddings max_pos_embeddings = self._get_max_position_embeddings() @@ -228,18 +222,12 @@ class EmbeddingMixin(OpenAIServing): chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" # Create engine prompt for this chunk - chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens) - - # Create chunk request prompt for logging - chunk_text = "" - chunk_request_prompt = TextTokensPrompt( - prompt=chunk_text, prompt_token_ids=chunk_tokens - ) + chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens) # Log the chunk self._log_inputs( chunk_request_id, - chunk_request_prompt, + chunk_engine_prompt, params=pooling_params, lora_request=ctx.lora_request, ) @@ -263,7 +251,7 @@ class EmbeddingMixin(OpenAIServing): request, input_ids: list[int], input_text: str, - ) -> TextTokensPrompt: + ) -> TokensPrompt: """Override to support chunked processing 
for embedding requests.""" token_num = len(input_ids) @@ -328,23 +316,15 @@ class EmbeddingMixin(OpenAIServing): ) ) - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # For other request types, use the parent's implementation return super()._validate_input(request, input_ids, input_text) - def _is_text_tokens_prompt(self, prompt) -> bool: - """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" - return ( - isinstance(prompt, dict) - and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt - ) - async def _create_single_prompt_generator( self, ctx: EmbeddingServeContext, - engine_prompt: EngineTokensPrompt, + engine_prompt: TokensPrompt, pooling_params: PoolingParams, trace_headers: Mapping[str, str] | None, prompt_index: int, @@ -413,14 +393,16 @@ class EmbeddingMixin(OpenAIServing): for i, engine_prompt in enumerate(ctx.engine_prompts): # Check if this specific prompt needs chunked processing - if self._is_text_tokens_prompt(engine_prompt): - # Cast to TextTokensPrompt since we've verified - # prompt_token_ids - text_tokens_prompt = cast(TextTokensPrompt, engine_prompt) - if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings: + if "prompt_token_ids" in engine_prompt: + prompt_token_ids = engine_prompt["prompt_token_ids"] + if len(prompt_token_ids) > max_pos_embeddings: # Use chunked processing for this prompt chunk_generators = await self._process_chunked_request( - ctx, text_tokens_prompt, pooling_params, trace_headers, i + ctx, + prompt_token_ids, + pooling_params, + trace_headers, + i, ) generators.extend(chunk_generators) continue @@ -578,14 +560,13 @@ class EmbeddingMixin(OpenAIServing): # Get original prompt token IDs for this prompt original_prompt = ctx.engine_prompts[prompt_idx] - if not self._is_text_tokens_prompt(original_prompt): + if "prompt_token_ids" not in original_prompt: return self.create_error_response( - f"Chunked prompt {prompt_idx} is not a TextTokensPrompt" + f"Chunked prompt {prompt_idx} does not contain " + "token IDs" ) - original_token_ids = cast(TextTokensPrompt, original_prompt)[ - "prompt_token_ids" - ] + original_token_ids = original_prompt["prompt_token_ids"] pooling_request_output = PoolingRequestOutput( request_id=aggregator["request_id"], diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 57f1a6440cf76..4e1b326806eae 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -137,11 +137,8 @@ class OpenAIServingPooling(OpenAIServing): ) if error_check_ret is not None: return error_check_ret - ( - _, - _, - engine_prompts, - ) = await self._preprocess_chat( + + _, engine_prompts = await self._preprocess_chat( request, tokenizer, request.messages, diff --git a/vllm/entrypoints/pooling/score/protocol.py b/vllm/entrypoints/pooling/score/protocol.py index a22219707c357..e81bda2eec3d7 100644 --- a/vllm/entrypoints/pooling/score/protocol.py +++ b/vllm/entrypoints/pooling/score/protocol.py @@ -120,6 +120,7 @@ class RerankResult(BaseModel): class RerankUsage(BaseModel): + prompt_tokens: int total_tokens: int diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index e5a66783005a6..edbfcd03ac92c 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -38,7 +38,8 @@ from vllm.inputs.data import TokensPrompt from vllm.logger import 
init_logger from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.async_utils import make_async, merge_async_iterators logger = init_logger(__name__) @@ -501,5 +502,7 @@ class ServingScores(OpenAIServing): id=request_id, model=model_name, results=results, - usage=RerankUsage(total_tokens=num_prompt_tokens), + usage=RerankUsage( + total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens + ), ) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index f31b309b8ca48..0f89c840be80f 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -12,9 +12,7 @@ import torch from pydantic import Field from vllm.config import ModelConfig -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt -from vllm.inputs.data import TextPrompt as EngineTextPrompt -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.tokenizers import TokenizerLike from vllm.utils.async_utils import AsyncMicrobatchTokenizer @@ -97,7 +95,7 @@ class BaseRenderer(ABC): *, prompt_or_prompts: str | list[str] | list[int] | list[list[int]], config: RenderConfig, - ) -> list[EngineTokensPrompt]: + ) -> list[TokensPrompt]: """ Convert text or token inputs into engine-ready TokensPrompt objects. @@ -115,7 +113,7 @@ class BaseRenderer(ABC): (e.g., tokenization and length handling). Returns: - list[EngineTokensPrompt]: Engine-ready token prompts. + list[TokensPrompt]: Engine-ready token prompts. Raises: ValueError: If input formats are invalid or length limits exceeded. @@ -129,7 +127,7 @@ class BaseRenderer(ABC): prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, prompt_embeds: bytes | list[bytes] | None = None, config: RenderConfig, - ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: + ) -> list[TokensPrompt | EmbedsPrompt]: """ Convert text/token and/or base64-encoded embeddings inputs into engine-ready prompt objects using a unified RenderConfig. @@ -146,7 +144,7 @@ class BaseRenderer(ABC): (e.g., tokenization and length handling). Returns: - list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + list[Union[TokensPrompt, EmbedsPrompt]]: Engine-ready prompt objects. Raises: @@ -161,31 +159,34 @@ class BaseRenderer(ABC): prompt_embeds: bytes | list[bytes], truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None, cache_salt: str | None = None, - ) -> list[EngineEmbedsPrompt]: + ) -> list[EmbedsPrompt]: """Load and validate base64-encoded embeddings into prompt objects.""" if not self.model_config.enable_prompt_embeds: raise ValueError( "You must set `--enable-prompt-embeds` to input `prompt_embeds`." 
) - def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: - tensor = torch.load( - io.BytesIO(pybase64.b64decode(embed, validate=True)), - weights_only=True, - map_location=torch.device("cpu"), - ) - assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( - torch.float32, - torch.bfloat16, - torch.float16, - ) - tensor = tensor.to_dense() + def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load( + io.BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) + tensor = tensor.to_dense() if tensor.dim() > 2: tensor = tensor.squeeze(0) assert tensor.dim() == 2 if truncate_prompt_tokens is not None: tensor = tensor[-truncate_prompt_tokens:] - embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor) + embeds_prompt = EmbedsPrompt(prompt_embeds=tensor) if cache_salt is not None: embeds_prompt["cache_salt"] = cache_salt return embeds_prompt @@ -213,7 +214,7 @@ class CompletionRenderer(BaseRenderer): *, prompt_or_prompts: str | list[str] | list[int] | list[list[int]], config: RenderConfig, - ) -> list[EngineTokensPrompt]: + ) -> list[TokensPrompt]: """Implementation of prompt rendering for completion-style requests. Uses async tokenizer pooling for improved performance. See base class @@ -240,7 +241,7 @@ class CompletionRenderer(BaseRenderer): prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, prompt_embeds: bytes | list[bytes] | None = None, config: RenderConfig, - ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: + ) -> list[TokensPrompt | EmbedsPrompt]: """ Render text/token prompts and/or precomputed embedding prompts. At least one of `prompt_or_prompts` or `prompt_embeds` must be provided. 
@@ -249,7 +250,7 @@ class CompletionRenderer(BaseRenderer): if truncate_prompt_tokens == 0: return [] - rendered: list[EngineTokensPrompt | EngineEmbedsPrompt] = [] + rendered: list[TokensPrompt | EmbedsPrompt] = [] if prompt_embeds is not None: rendered.extend( @@ -281,10 +282,10 @@ class CompletionRenderer(BaseRenderer): async def _create_prompt( self, - prompt_input: EngineTextPrompt | EngineTokensPrompt, + prompt_input: TextPrompt | TokensPrompt, config: RenderConfig, truncate_prompt_tokens: int | None, - ) -> EngineTokensPrompt: + ) -> TokensPrompt: prompt, prompt_token_ids, _ = get_prompt_components(prompt_input) if prompt_token_ids is not None: @@ -317,7 +318,7 @@ class CompletionRenderer(BaseRenderer): truncate_prompt_tokens: int | None, add_special_tokens: bool, cache_salt: str | None, - ) -> EngineTokensPrompt: + ) -> TokensPrompt: """Tokenize text input asynchronously.""" async_tokenizer = self._get_async_tokenizer() @@ -350,7 +351,7 @@ class CompletionRenderer(BaseRenderer): truncate_prompt_tokens: int | None, cache_salt: str | None, needs_detokenization: bool | None = False, - ) -> EngineTokensPrompt: + ) -> TokensPrompt: """Optionally detokenize token IDs and build a tokens prompt.""" token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens) @@ -392,8 +393,8 @@ class CompletionRenderer(BaseRenderer): max_length: int | None = None, cache_salt: str | None = None, prompt: str | None = None, - ) -> EngineTokensPrompt: - """Create validated EngineTokensPrompt.""" + ) -> TokensPrompt: + """Create validated TokensPrompt.""" if max_length is not None and len(token_ids) > max_length: raise ValueError( f"This model's maximum context length is {max_length} tokens. " @@ -401,7 +402,7 @@ class CompletionRenderer(BaseRenderer): "Please reduce the length of the input messages." ) - tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) + tokens_prompt = TokensPrompt(prompt_token_ids=token_ids) if cache_salt is not None: tokens_prompt["cache_salt"] = cache_salt if prompt is not None: diff --git a/vllm/entrypoints/responses_utils.py b/vllm/entrypoints/responses_utils.py index 99080fa43cb8e..df3d0495755da 100644 --- a/vllm/entrypoints/responses_utils.py +++ b/vllm/entrypoints/responses_utils.py @@ -16,7 +16,6 @@ from openai.types.responses.response import ToolChoice from openai.types.responses.response_function_tool_call_output_item import ( ResponseFunctionToolCallOutputItem, ) -from openai.types.responses.response_output_item import McpCall from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_reasoning_item import ResponseReasoningItem from openai.types.responses.tool import Tool @@ -27,38 +26,6 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionMessageParam, ResponseInputOutputItem, ) -from vllm.utils import random_uuid - - -def make_response_output_items_from_parsable_context( - response_messages: list[ResponseInputOutputItem], -) -> list[ResponseOutputItem]: - """Given a list of sentences, construct ResponseOutput Items.""" - output_messages: list[ResponseOutputItem] = [] - for message in response_messages: - if not isinstance(message, ResponseFunctionToolCallOutputItem): - output_messages.append(message) - else: - if len(output_messages) == 0: - raise ValueError( - "Cannot have a FunctionToolCallOutput before FunctionToolCall." 
- ) - if isinstance(output_messages[-1], ResponseFunctionToolCall): - mcp_message = McpCall( - id=f"{MCP_PREFIX}{random_uuid()}", - arguments=output_messages[-1].arguments, - name=output_messages[-1].name, - server_label=output_messages[ - -1 - ].name, # TODO: store the server label - type=f"{MCP_PREFIX}call", - status="completed", - output=message.output, - # TODO: support error output - ) - output_messages[-1] = mcp_message - - return output_messages def construct_input_messages( diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index 5c1d17156a90d..1798b174b1413 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -27,7 +27,7 @@ from vllm.entrypoints.serve.disagg.protocol import ( GenerateResponse, GenerateResponseChoice, ) -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import RequestOutput @@ -99,7 +99,7 @@ class ServingTokens(OpenAIServing): # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is # completed - engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids) if request.features is not None: engine_prompt["multi_modal_data"] = None @@ -115,7 +115,7 @@ class ServingTokens(OpenAIServing): self._log_inputs( request_id, - request.token_ids, + TokensPrompt(prompt_token_ids=request.token_ids), params=sampling_params, lora_request=lora_request, ) diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 979da02d14500..0b07f0b18dfd5 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig +from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike @@ -80,11 +81,8 @@ class OpenAIServingTokenization(OpenAIServing): ) if error_check_ret is not None: return error_check_ret - ( - _, - _, - engine_prompts, - ) = await self._preprocess_chat( + + _, engine_prompts = await self._preprocess_chat( request, tokenizer, request.messages, @@ -141,7 +139,10 @@ class OpenAIServingTokenization(OpenAIServing): tokenizer = await self.engine_client.get_tokenizer() self._log_inputs( - request_id, request.tokens, params=None, lora_request=lora_request + request_id, + TokensPrompt(prompt_token_ids=request.tokens), + params=None, + lora_request=lora_request, ) prompt_input = await self._tokenize_prompt_input_async( diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index daeeb995bc749..f4a633c69cb0b 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) diff --git a/vllm/envs.py b/vllm/envs.py index d0f2798096263..7e072a588591c 100755 --- 
a/vllm/envs.py +++ b/vllm/envs.py @@ -207,7 +207,7 @@ if TYPE_CHECKING: VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" - VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False + VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: str | None = None VLLM_NVFP4_GEMM_BACKEND: str | None = None @@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # kv-cache memory usage and enable longer contexts) # TODO(lucas): Remove this flag once latency regression is resolved. "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool( - int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")) + int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1")) ), # Enables support for the "store" option in the OpenAI Responses API. # When set to 1, vLLM's OpenAI server will retain the input and output diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 211551be8e60b..71289277eb987 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -33,22 +33,31 @@ def parse_raw_prompts( if len(prompt) == 0: raise ValueError("please provide at least one prompt") + # case 2: array of strings if is_list_of(prompt, str): - # case 2: array of strings prompt = cast(list[str], prompt) return [TextPrompt(prompt=elem) for elem in prompt] + + # case 3: array of tokens if is_list_of(prompt, int): - # case 3: array of tokens prompt = cast(list[int], prompt) return [TokensPrompt(prompt_token_ids=prompt)] - if is_list_of(prompt, list): - prompt = cast(list[list[int]], prompt) - if len(prompt[0]) == 0: - raise ValueError("please provide at least one prompt") - if is_list_of(prompt[0], int): - # case 4: array of token arrays - return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] + # case 4: array of token arrays + if is_list_of(prompt, list): + first = prompt[0] + if not isinstance(first, list): + raise ValueError("prompt expected to be a list of lists") + + if len(first) == 0: + raise ValueError("Please provide at least one prompt") + + # strict validation: every nested list must be list[int] + if not all(is_list_of(elem, int) for elem in prompt): + raise TypeError("Nested lists must contain only integers") + + prompt = cast(list[list[int]], prompt) + return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] raise TypeError( "prompt must be a string, array of strings, " diff --git a/vllm/lora/request.py b/vllm/lora/request.py index c97e435e32165..55756bdb103bd 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -14,11 +14,6 @@ class LoRARequest( """ Request for a LoRA adapter. - Note that this class should be used internally. For online - serving, it is recommended to not allow users to use this class but - instead provide another layer of abstraction to prevent users from - accessing unauthorized LoRA adapters. - lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. 
""" diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9ef696d80712c..66250f816f459 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -38,8 +38,9 @@ class CustomOp(nn.Module): ) return super().__new__(op_cls_to_instantiate) - def __init__(self): + def __init__(self, enforce_enable: bool = False): super().__init__() + self._enforce_enable = enforce_enable self._forward_method = self.dispatch_forward() def forward(self, *args, **kwargs): @@ -84,7 +85,11 @@ class CustomOp(nn.Module): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. compilation_config = get_cached_compilation_config() - enabled = self.enabled() + + # CustomOp object can be enforce enabled, e.g., enable device-specific + # kernels in ViT models when enabling graph mode. By default, it will + # follow the compilation_config to determine whether enable itself. + enabled = self._enforce_enable or self.enabled() if enabled: compilation_config.enabled_custom_ops.update([self.__class__.name]) else: diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index b14e7dad77f9a..fde0826779eb1 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -6,7 +6,7 @@ from typing import Any import torch -import vllm.envs as envs +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -936,7 +936,7 @@ def enable_batch_invariant_mode(): # Batch invariant matmuls are no longer needed after cublas overrides if not is_torch_equal_or_newer("2.10.0.dev"): if ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(80) or current_platform.is_device_capability(89) ): @@ -1004,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool: return VLLM_BATCH_INVARIANT -def override_envs_for_invariance(): - curr_attn_backend = envs.VLLM_ATTENTION_BACKEND +def override_envs_for_invariance( + attention_backend: AttentionBackendEnum | None, +): supported_backends = [ - "FLASH_ATTN", # best supported backend - "FLASHINFER", - "FLASH_ATTN_MLA", - "TRITON_MLA", + AttentionBackendEnum.FLASH_ATTN, # best supported backend + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.TRITON_MLA, # Not yet supported MLA backends - # "FLASHMLA", - # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance - # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967 + # AttentionBackendEnum.FLASHMLA, + # AttentionBackendEnum.FLEX_ATTENTION, # IMA issue + # AttentionBackendEnum.FLASHINFER_MLA, # PR #28967 ] - if curr_attn_backend not in supported_backends: + if attention_backend not in supported_backends: + supported_names = [b.name for b in supported_backends] + backend_name = attention_backend.name if attention_backend else None error = ( "VLLM batch_invariant mode requires an attention backend in " - f"{supported_backends}, but got '{curr_attn_backend}'. " - "Please set the 'VLLM_ATTENTION_BACKEND' environment variable " - "to one of the supported backends before enabling batch_invariant." + f"{supported_names}, but got '{backend_name}'. 
" + "Please use --attention-backend or attention_config to set " + "one of the supported backends before enabling batch_invariant." ) raise RuntimeError(error) - if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]: + if attention_backend != supported_backends[0]: warning = ( "You are using a decode-invariant form of batch invariance. " "This will not be invariant between prefill and decode." @@ -1050,10 +1053,12 @@ def override_envs_for_invariance(): os.environ["VLLM_USE_AOT_COMPILE"] = "0" -def init_batch_invariance(): +def init_batch_invariance( + attention_backend: AttentionBackendEnum | None, +): # this will hit all the csrc overrides as well if vllm_is_batch_invariant(): - override_envs_for_invariance() + override_envs_for_invariance(attention_backend) enable_batch_invariant_mode() # Disable TF32 for batch invariance - it causes non-deterministic rounding diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 53362277dae8a..15f6e3a18ed6c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): """ DeepGemm supports packed ue8m0 activation scales format in devices == sm100 """ - return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100) + return ( + is_deep_gemm_e8m0_used() + and current_platform.is_device_capability_family(100) + ) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 552e38a71bf98..4a0b4e82c1b39 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -460,7 +460,6 @@ def cutlass_moe_fp8( expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - parallel_config=None, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -538,7 +537,6 @@ def cutlass_moe_fp8( c_strides2=c_strides2, quant_config=quant_config, ), - parallel_config=parallel_config, ) return fn( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 4a64736ed767b..5ca91768c9760 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -293,7 +293,7 @@ def deep_gemm_moe_fp8( expert_map: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, - apply_router_weight_on_input=False, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0b83a3f5c4803..b286c3bc6fc07 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -885,12 +885,11 @@ def get_moe_configs( # If no optimized configuration is available, we will use the default # configuration - logger.warning( - ( - "Using default MoE config. Performance might be sub-optimal! 
" - "Config file not found at %s" - ), - config_file_paths, + logger.warning_once( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s", + ", ".join(config_file_paths), + scope="local", ) return None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 8c9d8a2777d58..a46e3972ed8e3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase): "implementation based on the prepare_finalize" ) + def prepare_dp_allgather_tensor( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.""" + raise NotImplementedError( + "Method 'prepare_dp_allgather_tensor' is not implemented in " + f"{self.__class__.__name__}." + ) + @abstractmethod def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 1947423bf4777..9c9bc2514bb4b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): prepare_finalize: FusedMoEPrepareAndFinalize, shared_experts: torch.nn.Module | None, ) -> "FusedMoEModularMethod": - parallel_config = getattr( - getattr(moe_layer, "vllm_config", None), - "parallel_config", - None, - ) return FusedMoEModularMethod( old_quant_method, FusedMoEModularKernel( @@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), shared_experts, getattr(moe_layer, "shared_experts_stream", None), - parallel_config=parallel_config, + moe_parallel_config=moe_layer.moe_parallel_config, ), ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7f803720d4770..b39ce415a0f83 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( is_flashinfer_supporting_global_sf, ) from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import ( aux_stream, @@ -369,7 +370,9 @@ class FusedMoE(CustomOp): # aux_stream() returns None on non-cuda-alike platforms. 
self.shared_experts_stream = aux_stream() if self.shared_experts_stream is not None: - logger.info_once("Enabled separate cuda stream for MoE shared_experts") + logger.info_once( + "Enabled separate cuda stream for MoE shared_experts", scope="local" + ) if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -1200,10 +1203,14 @@ class FusedMoE(CustomOp): if full_load: shard_dim += 1 - # Materialize GGUF UninitializedParameter + # Materialize GGUF UninitializedParameter, accounting for merged weights if is_gguf_weight and isinstance(param, UninitializedParameter): + # To materialize a tensor, we must have the full shape, including + # the number of experts, so this path requires `full_load`. + assert full_load final_shape = list(loaded_weight.shape) - if shard_id in ["w1", "w3"]: + # w1 and w3 are merged per expert. + if shard_id in {"w1", "w3"}: final_shape[1] *= 2 final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size param.materialize(final_shape, dtype=loaded_weight.dtype) @@ -1927,10 +1934,46 @@ class FusedMoE(CustomOp): ) with sp_ctx: + extra_tensors = None if do_naive_dispatch_combine: - hidden_states_combined, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel + # Avoid circular import + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4FusedMoE, ) + + post_quant_allgather = ( + has_flashinfer_trtllm_fused_moe() + and self.quant_method is not None + and self.dp_size > 1 + and self.use_ep + and isinstance(self.quant_method, ModelOptNvFp4FusedMoE) + ) + if post_quant_allgather: + hidden_states_to_dispatch, extra_tensors = ( + self.quant_method.prepare_dp_allgather_tensor( + self, hidden_states, router_logits + ) + ) + else: + hidden_states_to_dispatch = hidden_states + + dispatch_res = get_ep_group().dispatch( + hidden_states_to_dispatch, + router_logits, + self.is_sequence_parallel, + extra_tensors=extra_tensors, + ) + if extra_tensors is not None: + hidden_states_combined, router_logits, extra_tensors_combined = ( + dispatch_res + ) + hidden_states_combined = ( + hidden_states_combined, + extra_tensors_combined[0], + ) + else: + hidden_states_combined, router_logits = dispatch_res + # Run shared experts before matrix multiply. # because matrix multiply maybe modify the hidden_states. 
if has_separate_shared_experts and not use_shared_experts_stream: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 9e75a7c08070e..484314091cb15 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,10 +10,12 @@ from typing import final import torch import vllm.envs as envs -from vllm.config import ParallelConfig, get_current_vllm_config from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, count_expert_num_tokens, @@ -681,7 +683,7 @@ class FusedMoEModularKernel(torch.nn.Module): fused_experts: FusedMoEPermuteExpertsUnpermute, shared_experts: torch.nn.Module | None = None, shared_experts_stream: torch.cuda.Stream | None = None, - parallel_config: ParallelConfig | None = None, + moe_parallel_config: FusedMoEParallelConfig | None = None, ): super().__init__() self.prepare_finalize = prepare_finalize @@ -689,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module): self.shared_experts = shared_experts self.shared_experts_stream = shared_experts_stream - # cache whether this worker is using DP+EP - if parallel_config is None: - parallel_config = get_current_vllm_config().parallel_config + # prefer an explicit FusedMoEParallelConfig when available (from + # FusedMoE layers / tests). + # if not provided, assume this kernel is + # running in a non-DP+EP context + self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config self.is_dp_ep = ( - parallel_config.data_parallel_size > 1 - and parallel_config.enable_expert_parallel + moe_parallel_config is not None + and moe_parallel_config.dp_size > 1 + and moe_parallel_config.use_ep ) self._post_init_setup() diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 9aaeec4f98a61..a143347b19f2c 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE): self._shared_experts = shared_experts # Disable shared expert overlap if: - # - we are using eplb, because of correctness issues + # - we are using eplb with non-default backend, because of correctness issues # - we are using flashinfer with DP, since there nothint to gain - # - we are using marlin kjernels + # - we are using marlin kernels + backend = self.moe_parallel_config.all2all_backend self.use_overlapped = ( use_overlapped and not ( - # TODO(wentao): find the root cause and remove this condition - self.enable_eplb + (self.enable_eplb and backend != "allgather_reducescatter") or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) ) and self._shared_experts is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5ad26f9318df3..c302e465aedb7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -469,16 
+469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ) logger.debug_once("Finished shuffling weights for TRT-LLM MOE") - layer.gemm1_weights_fp4_shuffled = Parameter( + layer.w13_weight = Parameter( gemm1_weights_fp4_shuffled, requires_grad=False ) - layer.gemm2_weights_fp4_shuffled = Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = Parameter( + layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) + layer.w13_weight_scale = Parameter( gemm1_scales_fp4_shuffled, requires_grad=False ) - layer.gemm2_scales_fp4_shuffled = Parameter( + layer.w2_weight_scale = Parameter( gemm2_scales_fp4_shuffled, requires_grad=False ) @@ -487,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale else: # swizzle weight scales layer.w13_weight_scale = torch.nn.Parameter( @@ -634,17 +626,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: + # If no modular kernel is provided, use cutlass_moe_fp4 for TP case + # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - assert layer.expert_map is None, ( - "Expert Parallelism / expert_map " - "is currently not supported for " - "CompressedTensorsW4A4Nvfp4MoEMethod." - ) assert self.moe_quant_config is not None - - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, @@ -652,6 +638,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, + expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, # TODO(bnell): derive these from arguments m=x.shape[0], @@ -1266,9 +1253,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, c_strides2=self.ab_strides1_c_strides2, - parallel_config=getattr( - getattr(layer, "vllm_config", None), "parallel_config", None - ), ) else: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 60dde9eb57e0f..f2b66a2beb6d7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -137,7 +137,7 @@ def get_fp8_moe_backend( if ( current_platform.is_cuda() and ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) and envs.VLLM_USE_FLASHINFER_MOE_FP8 @@ -148,7 +148,7 @@ def get_fp8_moe_backend( logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") return Fp8MoeBackend.FLASHINFER_TRTLLM else: - if block_quant and current_platform.is_device_capability(100): + if block_quant and current_platform.is_device_capability_family(100): raise ValueError( "FlashInfer FP8 MoE throughput backend does not " "support block quantization. 
Please use " @@ -193,7 +193,7 @@ def get_fp8_moe_backend( # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights if ( current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and block_quant ): logger.info_once( @@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig): fused_mapping=self.packed_modules_mapping, ): return UnquantizedFusedMoEMethod(layer.moe_config) - moe_quant_method = Fp8MoEMethod(self, layer) + if self.is_checkpoint_fp8_serialized: + moe_quant_method = Fp8MoEMethod(self, layer) + else: + moe_quant_method = Fp8OnlineMoEMethod(self, layer) moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return moe_quant_method elif isinstance(layer, Attention): @@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.orig_dtype = params_dtype layer.weight_block_size = None - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn + assert self.quant_config.is_checkpoint_fp8_serialized + params_dtype = torch.float8_e4m3fn + if self.block_quant: assert self.weight_block_size is not None layer.weight_block_size = self.weight_block_size @@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): f"weight quantization block_k = {block_k}." ) - # if we are doing online quantization, patch the weight - # loaded to call `process_weights_after_loading` in a streaming fashion - # as soon as the last weight chunk is loaded - if not self.quant_config.is_checkpoint_fp8_serialized: - weight_loader = extra_weight_attrs["weight_loader"] - # create a new holder to prevent modifying behavior of any other - # objects which might depend on the old one - new_extra_weight_attrs = extra_weight_attrs - - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # load the current weight chunk - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - - # add a counter to track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - layer._loaded_numel += loaded_weight.numel() - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Delete the bookkeeping - del layer._loaded_numel - # Prevent the usual `process_weights_after_loading` call - # from doing anything - layer._already_called_process_weights_after_loading = True - - return res - - new_extra_weight_attrs["weight_loader"] = patched_weight_loader - extra_weight_attrs = new_extra_weight_attrs - # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( @@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.block_quant else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) - # If loading fp8 checkpoint, pass the weight loaders. 
- # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - w13_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) @@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale_inv = Parameter( dg_w2_weight_scale_inv, requires_grad=False ) - - # If checkpoint is fp16, quantize in place. - elif not self.quant_config.is_checkpoint_fp8_serialized: - fp8_dtype = current_platform.fp8_dtype() - w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - replace_parameter( - layer, - "w13_weight_scale", - torch.ones( - layer.local_num_experts, - dtype=torch.float32, - device=w13_weight.device, - ), - ) - for expert in range(layer.local_num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - replace_parameter(layer, "w13_weight", w13_weight) - replace_parameter(layer, "w2_weight", w2_weight) - - if self.rocm_aiter_moe_enabled: - # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( - layer.w13_weight, layer.w2_weight - ) - - replace_parameter(layer, "w13_weight", shuffled_w13) - replace_parameter(layer, "w2_weight", shuffled_w2) - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. else: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. @@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase): return result +class Fp8OnlineMoEMethod(Fp8MoEMethod): + """MoE method for online FP8 quantization. + Supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(quant_config, layer) + assert not quant_config.is_checkpoint_fp8_serialized + assert quant_config.activation_scheme == "dynamic" + assert quant_config.weight_block_size is None + assert self.flashinfer_moe_backend is None + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # We are doing online quantization, patch the weight loaded + # to call `process_weights_after_loading` in a streaming fashion + # as soon as the last weight chunk is loaded. + weight_loader = extra_weight_attrs["weight_loader"] + # create a new holder to prevent modifying behavior of any other + # objects which might depend on the old one + new_extra_weight_attrs = extra_weight_attrs + + def patched_weight_loader(param, loaded_weight, *args, **kwargs): + # load the current weight chunk + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + + # add a counter to track how many elements we have updated + if not hasattr(layer, "_loaded_numel"): + layer._loaded_numel = 0 + layer._loaded_numel += loaded_weight.numel() + + # if we have loaded all of the elements, call + # process_weights_after_loading + target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() + if layer._loaded_numel == target_loaded_numel: + self.process_weights_after_loading(layer) + + # Delete the bookkeeping + del layer._loaded_numel + # Prevent the usual `process_weights_after_loading` call + # from doing anything + layer._already_called_process_weights_after_loading = True + + return res + + new_extra_weight_attrs["weight_loader"] = patched_weight_loader + extra_weight_attrs = new_extra_weight_attrs + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + self.rocm_aiter_moe_enabled = False + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + # Lazy import to avoid importing triton too early. + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + + # If checkpoint is fp16, quantize in place. 
+ fp8_dtype = current_platform.fp8_dtype() + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + for expert in range(layer.local_num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w2_weight", w2_weight) + + # Reshuffle weights for AITER if needed. + if self.rocm_aiter_moe_enabled: + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( + layer.w13_weight, layer.w2_weight + ) + replace_parameter(layer, "w13_weight", shuffled_w13) + replace_parameter(layer, "w2_weight", shuffled_w2) + + # Reshuffle weights for MARLIN if needed. + if self.use_marlin: + prepare_moe_fp8_layer_for_marlin( + layer, False, input_dtype=self.marlin_input_dtype + ) + + class Fp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index bd1d399715305..20d050d387d49 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -62,7 +62,7 @@ def choose_scaled_mm_linear_kernel( continue # If the current platform uses compute_capability, - # make sure the kernel supports the compute cability. + # make sure the kernel supports the compute capability. is_supported, reason = kernel.is_supported(compute_capability) if not is_supported: failure_reasons.append(f"{kernel.__name__}: {reason}") diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 18a0fe6fbbb44..d5d7e7bfaae73 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -81,6 +81,7 @@ from vllm.utils.flashinfer import ( has_flashinfer, has_flashinfer_moe, ) +from vllm.utils.math_utils import round_up if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -187,7 +188,24 @@ class ModelOptQuantConfigBase(QuantizationConfig): def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if len(self.exclude_modules) > 0: - self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) + # This is a workaround for the weights remapping issue: + # https://github.com/vllm-project/vllm/issues/28072 + # Right now, the Nvidia ModelOpt library uses just one wildcard pattern: + # module_path* + # It gets applied if the whole tree of modules rooted at module_path + # is not quantized. 
Here we replace such a pattern with 2 patterns that are + # collectively equivalent to the original pattern: + # module_path + # module_path.* + new_exclude_modules = [] + for exclude in self.exclude_modules: + if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".": + new_exclude_modules.append(exclude[:-1]) + new_exclude_modules.append(exclude[:-1] + ".*") + else: + new_exclude_modules.append(exclude) + + self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules) @staticmethod def get_config_filenames() -> list[str]: @@ -607,6 +625,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): Only supports pre-quantized checkpoints with FP8 weights and scales. """ + if self.flashinfer_moe_backend is not None: + self._maybe_pad_intermediate_for_flashinfer(layer) + layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) @@ -684,6 +705,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) register_moe_scaling_factors(layer) + def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: + """Pad intermediate size so FlashInfer kernels' alignment constraints hold. + + Some FlashInfer FP8 MoE kernels require the (gated) intermediate size + used for GEMM to be divisible by a small alignment value. When this is + not satisfied (e.g. with certain tensor-parallel sizes), we pad the + gate/up and down projection weights along the intermediate dim. + """ + if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"): + return + + # Current local intermediate size (per partition) is the K dimension of + # the down projection. + num_experts, hidden_size, intermediate = layer.w2_weight.shape + + min_alignment = 16 + padded_intermediate = round_up(intermediate, min_alignment) + + if padded_intermediate == intermediate: + return + + logger.info( + "Padding intermediate size from %d to %d for up/down projection weights.", + intermediate, + padded_intermediate, + ) + + up_mult = 2 if self.moe.is_act_and_mul else 1 + padded_gate_up_dim = up_mult * padded_intermediate + + # Pad w13 and w2 along their intermediate dimension. 
+ w13 = layer.w13_weight.data + padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size)) + padded_w13[:, : w13.shape[1], :] = w13 + layer.w13_weight.data = padded_w13 + + w2 = layer.w2_weight.data + padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate)) + padded_w2[:, :, :intermediate] = w2 + layer.w2_weight.data = padded_w2 + + if hasattr(layer, "intermediate_size_per_partition"): + layer.intermediate_size_per_partition = padded_intermediate + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: @@ -1393,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ) logger.debug_once("Finished shuffling weights for TRT-LLM MOE") - layer.gemm1_weights_fp4_shuffled = Parameter( + layer.w13_weight = Parameter( gemm1_weights_fp4_shuffled, requires_grad=False ) - layer.gemm2_weights_fp4_shuffled = Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = Parameter( + layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) + layer.w13_weight_scale = Parameter( gemm1_scales_fp4_shuffled, requires_grad=False ) - layer.gemm2_scales_fp4_shuffled = Parameter( + layer.w2_weight_scale = Parameter( gemm2_scales_fp4_shuffled, requires_grad=False ) @@ -1411,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale elif self.use_marlin: # Marlin processing prepare_moe_fp4_layer_for_marlin(layer) @@ -1465,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): w2_blockscale_swizzled, requires_grad=False ) + def prepare_dp_allgather_tensor( + self, + layer: FusedMoE, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Optionally prepare extra tensors to carry through DP allgather/EP.""" + import flashinfer + + a1_gscale = layer.w13_input_scale_quant + hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( + hidden_states, + a1_gscale, + is_sf_swizzled_layout=False, + ) + extra_tensors: list[torch.Tensor] = [hidden_states_sf] + return hidden_states_fp4, extra_tensors + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: @@ -1519,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): e_score_correction_bias=layer.e_score_correction_bias, ) + # Hidden_states in select_experts is only used to extract metadata + if isinstance(x, tuple): + x_routing, _ = x + else: + x_routing = x topk_weights, topk_ids, _ = layer.select_experts( - hidden_states=x, + hidden_states=x_routing, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 0131a330f70d2..4bedb951a33f5 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) +from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( + UnquantizedFusedMoEMethod, +) from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from 
vllm.model_executor.layers.quantization.base_config import ( @@ -162,6 +165,8 @@ class MoeWNA16Config(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: if is_layer_skipped_quant(prefix, self.modules_to_not_convert): + if isinstance(layer, FusedMoE): + return UnquantizedFusedMoEMethod(layer.moe_config) return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): # Avoid circular import diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 6eae4e9e66e1b..e96e87d15787d 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -118,19 +118,19 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") return Mxfp4Backend.SM90_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS ): logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 ): return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - elif current_platform.is_device_capability(100) and has_flashinfer(): + elif current_platform.is_device_capability_family(100) and has_flashinfer(): logger.info_once( "Using FlashInfer MXFP4 BF16 backend for SM100, " "For faster performance on SM100, consider setting " @@ -139,7 +139,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: ) return Mxfp4Backend.SM100_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) and not has_flashinfer(): logger.warning_once( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 8f96222f19f20..1d410316d6299 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -50,7 +50,7 @@ def is_flashinfer_fp4_cutedsl_moe_available() -> bool: envs.VLLM_USE_FLASHINFER_MOE_FP4 and has_flashinfer_cutedsl_grouped_gemm_nt_masked() and current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) ) @@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( def flashinfer_trtllm_fp4_moe( layer: torch.nn.Module, - x: torch.Tensor, + x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], router_logits: torch.Tensor, top_k: int, global_num_experts: int, @@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe( from vllm.model_executor.models.llama4 import Llama4MoE # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # hidden_states is the already quantized + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Determine 
routing method type use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function @@ -301,18 +305,14 @@ def flashinfer_trtllm_fp4_moe( hidden_states_scale=hidden_states_scale_linear_fp4.view( torch.float8_e4m3fn ).flatten(), - gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), + gemm1_weights=layer.w13_weight.data, + gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), + gemm2_weights=layer.w2_weight.data, + gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c.data, output1_scale_gate_scalar=layer.g1_alphas.data, @@ -364,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe( torch.bfloat16 ).view(torch.int16) - # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + # Hidden_states is the already quantized + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # Quantize input to FP4 + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Call TRT-LLM FP4 block-scale MoE kernel out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( @@ -380,18 +384,14 @@ def flashinfer_trtllm_fp4_routed_moe( hidden_states_scale=hidden_states_scale_linear_fp4.view( torch.float8_e4m3fn ).flatten(), - gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), + gemm1_weights=layer.w13_weight.data, + gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), + gemm2_weights=layer.w2_weight.data, + gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c.data, output1_scale_gate_scalar=layer.g1_alphas.data, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index ba3653e4b5ea7..3d6e9cda87667 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -247,11 +247,6 @@ def flashinfer_cutlass_moe_fp8( assert quant_config is not None # Construct modular kernel with block-scale support when requested. 
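# Illustrative sketch of the dual input convention adopted above: `x` may arrive
# either as raw hidden states or as a (fp4_tensor, scale) pair pre-quantized by
# prepare_dp_allgather_tensor. The helper name below is hypothetical.
def _unpack_or_quantize(x, a1_gscale):
    if isinstance(x, tuple):
        return x  # already quantized during DP allgather preparation
    import flashinfer
    return flashinfer.fp4_quantize(x, a1_gscale, is_sf_swizzled_layout=False)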
- parallel_config = getattr( - getattr(layer, "vllm_config", None), - "parallel_config", - None, - ) fused_experts = mk.FusedMoEModularKernel( build_flashinfer_fp8_cutlass_moe_prepare_finalize( moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale @@ -262,7 +257,7 @@ def flashinfer_cutlass_moe_fp8( out_dtype=hidden_states.dtype, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, ), - parallel_config=parallel_config, + moe_parallel_config=layer.moe_parallel_config, ) return fused_experts( @@ -290,7 +285,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: if flashinfer_moe_backend in backend_map: if ( flashinfer_moe_backend == "latency" - and not current_platform.has_device_capability(100) + and not current_platform.is_device_capability_family(100) ): logger.info_once( "Flashinfer TRTLLM MOE backend is only supported on " diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index e12fe61bf3d97..ea68745585160 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -247,7 +247,7 @@ class W8A8BlockFp8LinearOp: self.act_quant_group_shape = act_quant_group_shape self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_hopper = current_platform.is_device_capability(90) - self.is_blackwell = current_platform.is_device_capability(100) + self.is_blackwell = current_platform.is_device_capability_family(100) self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() # Get the correct blockscale mul and input quant operations. @@ -762,9 +762,12 @@ def per_token_group_quant_fp8( ) assert x.stride(-1) == 1, "`x` groups must be contiguous" + # Using the default value (240.0) from pytorch will cause accuracy + # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm + # platforms that use the torch.float8_e4mefnuz dtype. 
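# For context on the clamp below, the two FP8 formats have different ranges,
# which torch reports directly (standalone check):
import torch
assert torch.finfo(torch.float8_e4m3fnuz).max == 240.0  # ROCm "fnuz" variant
assert torch.finfo(torch.float8_e4m3fn).max == 448.0    # OCP variant
# The code below clamps fnuz platforms to +/-224.0 instead of using finfo as-is.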
finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max + fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min + fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max assert out_q is None or out_q.shape == x.shape x_q = out_q diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 7a351afb3c415..e9ecf0547033d 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -63,7 +63,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): "split_k": 1, } opt_flags.update_opt_flags_constraints(constraints) - elif current_platform.is_device_capability(100): + elif current_platform.is_device_capability_family(100): constraints = { "is_persistent": True, "epilogue_subtile": 1, diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 4114b21168cc8..afa69324c4e2e 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -7,7 +7,7 @@ import torch from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp -from .common import apply_rotary_emb_torch +from .common import ApplyRotaryEmb @CustomOp.register("rotary_embedding") @@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp): rocm_aiter_ops.is_triton_rotary_embed_enabled() ) + self.apply_rotary_emb = ApplyRotaryEmb( + is_neox_style=self.is_neox_style, + ) + def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to @@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, head_size) query_rot = query[..., :rotary_dim] query_pass = query[..., rotary_dim:] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style) + query_rot = ApplyRotaryEmb.forward_static( + query_rot, + cos, + sin, + is_neox_style, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) # key may be None in some cases, e.g. 
cross-layer KV sharing @@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase): key = key.view(num_tokens, -1, head_size) key_rot = key[..., :rotary_dim] key_pass = key[..., rotary_dim:] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style) + key_rot = ApplyRotaryEmb.forward_static( + key_rot, + cos, + sin, + is_neox_style, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 13f8d15cc0f72..3e6584dbc3da0 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -2,19 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from collections.abc import Callable -from functools import cache from importlib.util import find_spec import torch from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.model_executor.custom_op import CustomOp from vllm.utils.torch_utils import direct_register_custom_op -if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - logger = init_logger(__name__) @@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) -def apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -def apply_rotary_emb_dispatch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool -) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ - if current_platform.is_cuda(): - return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0) - else: - return apply_rotary_emb_torch(x, cos, sin, is_neox_style) - - -@cache -def dispatch_rotary_emb_function( - default: Callable[..., torch.Tensor] | None = None, -) -> Callable[..., torch.Tensor]: - if current_platform.is_cuda(): - return apply_rotary_emb - - # if torch compile is not enabled - # use rotary embedding function from flash_attn package - # otherwise use the naive pytorch embedding implementation - # is faster when torch compile is enabled. - if current_platform.is_rocm() and not torch.compiler.is_compiling(): - if find_spec("flash_attn") is not None: - from flash_attn.ops.triton.rotary import apply_rotary - - return apply_rotary - else: - logger.warning( - "flash_attn is not installed. Falling back to PyTorch " - "implementation for rotary embeddings." 
- ) - if default is not None: - return default - - return apply_rotary_emb_torch - - # yarn functions # Inverse dim formula to find dim based on number of rotations def yarn_find_correction_dim( @@ -186,3 +116,155 @@ direct_register_custom_op( mutates_args=["query", "key"], # These tensors are modified in-place fake_impl=_flashinfer_rotary_embedding_fake, ) + + +@CustomOp.register("apply_rotary_emb") +class ApplyRotaryEmb(CustomOp): + def __init__( + self, + enforce_enable: bool = False, + is_neox_style: bool = True, + enable_fp32_compute: bool = False, + ) -> None: + super().__init__(enforce_enable) + self.is_neox_style = is_neox_style + self.enable_fp32_compute = enable_fp32_compute + + self.apply_rotary_emb_flash_attn = None + if find_spec("flash_attn") is not None: + from flash_attn.ops.triton.rotary import apply_rotary + + self.apply_rotary_emb_flash_attn = apply_rotary + + @staticmethod + def forward_static( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool = True, + enable_fp32_compute: bool = False, + ) -> torch.Tensor: + """ + Args: + x: [batch_size (optional), seq_len, num_heads, head_size] + cos: [seq_len, head_size // 2] + sin: [seq_len, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style. + enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype + for higher accuracy. + """ + origin_dtype = x.dtype + if enable_fp32_compute: + x = x.float() + + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + + if is_neox_style: + output = torch.cat((o1, o2), dim=-1) + else: + output = torch.stack((o1, o2), dim=-1).flatten(-2) + + if enable_fp32_compute: + output = output.to(origin_dtype) + return output + + def forward_native( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + output = self.forward_static( + x, cos, sin, self.is_neox_style, self.enable_fp32_compute + ) + return output + + def forward_cuda( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + origin_dtype = x.dtype + if self.enable_fp32_compute: + x = x.float() + cos = cos.float() + sin = sin.float() + + origin_shape = x.shape + if len(origin_shape) == 3: + # x: [seq_len, num_heads, head_size] + x = x.unsqueeze(0) + + """ + Arguments of apply_rotary_emb() in vllm_flash_attn: + x: [batch_size, seq_len, nheads, headdim] + cos, sin: [seqlen_rotary, rotary_dim / 2] + interleaved: defalut as False (Neox-style). + ... + """ + interleaved = not self.is_neox_style + output = apply_rotary_emb(x, cos, sin, interleaved) + + if len(origin_shape) == 3: + output = output.squeeze(0) + if self.enable_fp32_compute: + output = output.to(origin_dtype) + return output + + def forward_hip( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + if self.apply_rotary_emb_flash_attn is not None: + origin_dtype = x.dtype + if self.enable_fp32_compute: + x = x.float() + cos = cos.float() + sin = sin.float() + + origin_shape = x.shape + if len(origin_shape) == 3: + # x: [seq_len, num_heads, head_size] + x = x.unsqueeze(0) + + """ + Arguments of apply_rotary() in flash_attn: + x: [batch_size, seq_len, nheads, headdim] + cos, sin: [seqlen_rotary, rotary_dim / 2] + interleaved: defalut as False (Neox-style). + ... 
+ """ + interleaved = not self.is_neox_style + output = self.apply_rotary_emb_flash_attn( + x, cos, sin, interleaved=interleaved + ).type_as(x) + + if len(origin_shape) == 3: + output = output.squeeze(0) + if self.enable_fp32_compute: + output = output.to(origin_dtype) + else: + # Falling back to PyTorch native implementation. + output = self.forward_native(x, cos, sin) + + return output + + def extra_repr(self) -> str: + s = f"is_neox_style={self.is_neox_style}" + s += f"enable_fp32_compute={self.enable_fp32_compute}" + return s diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 749cdbe88a62e..2eda63a34ac44 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -4,7 +4,6 @@ import torch -from .common import apply_rotary_emb_dispatch from .mrope import MRotaryEmbedding @@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 0592aa8f967a6..a74bf092b182b 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -8,7 +8,6 @@ import torch from vllm.triton_utils import tl, triton from .base import RotaryEmbeddingBase -from .common import apply_rotary_emb_dispatch from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale @@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) 
key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py index 2432273faf195..dab7aad9759a2 100644 --- a/vllm/model_executor/layers/rotary_embedding/xdrope.py +++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py @@ -4,7 +4,6 @@ import numpy as np import torch -from .common import apply_rotary_emb_dispatch from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding @@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding): dtype, ) - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """PyTorch-native implementation equivalent to forward(). 
+ + Args: + positions: + [4, num_tokens] (P/W/H/T positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1 + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1 + ) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = self.apply_rotary_emb( + query_rot, + cos, + sin, + ) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = self.apply_rotary_emb( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 9ba76f312edac..504de9fe10871 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T: tokens = getattr(text_config, "classifier_from_token", None) method = getattr(text_config, "method", None) + def auto_set_score_bias(weights): + for name, weight in weights: + if name == "score.bias": + device = self.score.weight.device + dtype = self.score.weight.dtype + bias = weight.to(device).to(dtype) + self.score.bias = torch.nn.Parameter(bias) + self.score.skip_bias_add = False + else: + yield name, weight + + weights = auto_set_score_bias(weights) if tokens is None and method is None: return super().load_weights(weights) else: diff --git a/vllm/model_executor/models/audioflamingo3.py b/vllm/model_executor/models/audioflamingo3.py new file mode 100644 index 0000000000000..0ca5f2c4e0a75 --- /dev/null +++ b/vllm/model_executor/models/audioflamingo3.py @@ -0,0 +1,639 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
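# Standalone sketch of the xdrope_section mixing used by XDRotaryEmbedding above
# (hypothetical sizes: 4 position streams P/W/H/T, 10 tokens, rotary_dim 128):
import torch
cos = torch.randn(4, 10, 64)       # [streams, num_tokens, rotary_dim // 2]
xdrope_section = [16, 16, 16, 16]  # sections sum to rotary_dim // 2
mixed = torch.cat(
    [m[i] for i, m in enumerate(cos.split(xdrope_section, dim=-1))], dim=-1
)                                  # -> [10, 64]; slice i is taken from stream i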
+ +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal, TypeAlias + +import torch +import torch.nn as nn +from transformers import BatchFeature, PretrainedConfig +from transformers.models.audioflamingo3 import ( + AudioFlamingo3Config, + AudioFlamingo3Processor, +) +from transformers.models.qwen2_audio import Qwen2AudioEncoder + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + init_vllm_registered_model, + maybe_prefix, +) + +MAX_AUDIO_LEN = 10 * 60 + + +# === Audio Inputs === # +class AudioFlamingo3FeatureInputs(TensorSchema): + """ + Dimensions: + - num_chunks: Number of audio chunks (flattened) + - nmb: Number of mel bins + - num_audios: Number of original audio files + """ + + type: Literal["audio_features"] + input_features: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("num_chunks", "nmb", 3000), + ] + + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("num_chunks", 3000), + ] + + chunk_counts: Annotated[ + torch.Tensor, + TensorShape("num_audios"), + ] + + +class AudioFlamingo3EmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size + - naf: Number of audio features + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + list[torch.Tensor], + TensorShape("bn", "naf", "hs"), + ] + + +AudioFlamingo3Inputs: TypeAlias = ( + AudioFlamingo3FeatureInputs | AudioFlamingo3EmbeddingInputs +) + + +class AudioFlamingo3Encoder(Qwen2AudioEncoder): + def __init__( + self, + config: PretrainedConfig, + ): + super().__init__(config) + self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2) + # self.layer_norm is already initialized in super().__init__ + + def forward( + self, + input_features: torch.Tensor | list[torch.Tensor], + attention_mask: torch.Tensor = None, + ): + # input_features: (batch, num_mel_bins, seq_len) + if isinstance(input_features, list): + input_features = torch.stack(input_features) + + hidden_states = nn.functional.gelu(self.conv1(input_features)) + hidden_states = nn.functional.gelu(self.conv2(hidden_states)) + hidden_states = hidden_states.transpose(-1, -2) + hidden_states = ( + hidden_states + self.embed_positions.weight[: hidden_states.size(-2), :] + ).to(hidden_states.dtype) + + for layer in self.layers: + layer_outputs = layer(hidden_states, attention_mask) + hidden_states = layer_outputs[0] + + # AvgPool (time/2) + LayerNorm + # hidden_states: (batch, seq_len, hidden_size) + hidden_states = hidden_states.permute(0, 2, 1) # 
(batch, hidden_size, seq_len) + hidden_states = self.avg_pooler(hidden_states) + hidden_states = hidden_states.permute( + 0, 2, 1 + ) # (batch, seq_len/2, hidden_size) + hidden_states = self.layer_norm(hidden_states) + + return hidden_states + + def _get_feat_extract_output_lengths(self, input_lengths: torch.Tensor): + """ + Computes the output length of the convolutional layers and the output length + of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class AudioFlamingo3MultiModalProjector(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.linear_1 = nn.Linear( + config.audio_config.hidden_size, + config.text_config.hidden_size, + bias=config.projector_bias, + ) + self.act = get_act_fn(config.projector_hidden_act) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=config.projector_bias, + ) + + def forward(self, audio_features): + hidden_states = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(AudioFlamingo3Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs) + + def get_feature_extractor(self, **kwargs: object): + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None} + + +class AudioFlamingo3DummyInputsBuilder( + BaseDummyInputsBuilder[AudioFlamingo3ProcessingInfo] +): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + hf_processor = self.info.get_hf_processor() + audio_token = hf_processor.audio_token + return audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + sampling_rate = feature_extractor.sampling_rate + audio_len = MAX_AUDIO_LEN * sampling_rate + num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") if mm_options else None + + return { + "audio": self._get_dummy_audios( + length=audio_len, + num_audios=num_audios, + overrides=audio_overrides, + ) + } + + +def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]): + chunk_counts = hf_inputs.get("chunk_counts") + if chunk_counts is not None: + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.flat_from_sizes( + "audio", chunk_counts, dim=0 + ), + feature_attention_mask=MultiModalFieldConfig.flat_from_sizes( + "audio", chunk_counts, dim=0 + ), + chunk_counts=MultiModalFieldConfig.batched("audio"), + ) + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + chunk_counts=MultiModalFieldConfig.batched("audio"), + ) + + +class AudioFlamingo3MultiModalDataParser(MultiModalDataParser): + def _parse_audio_data( + self, + data: dict[str, torch.Tensor] | ModalityData[Any], + ) -> ModalityDataItems[Any, Any] 
| None: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_audioflamingo3_field_config, + ) + return super()._parse_audio_data(data) + + +class AudioFlamingo3MultiModalProcessor( + BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo] +): + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return AudioFlamingo3MultiModalDataParser( + target_sr=feature_extractor.sampling_rate + ) + + def _call_hf_processor( + self, + prompt: str, + mm_data: dict[str, object], + mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + audios = mm_data.pop("audios", []) + if audios: + mm_data["audio"] = audios + + if not mm_data.get("audio", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + + # Calculate chunk counts + audio_list = mm_data.get("audio") + if not isinstance(audio_list, list): + audio_list = [audio_list] + + chunk_counts = [] + sampling_rate = feature_extractor.sampling_rate + chunk_length = feature_extractor.chunk_length + window_size = int(sampling_rate * chunk_length) + # MAX_AUDIO_LEN is 10 * 60 in HF processor. + max_windows = int(MAX_AUDIO_LEN // chunk_length) + + for audio in audio_list: + # audio is numpy array or list + n_samples = len(audio) if isinstance(audio, list) else audio.shape[0] + + n_win = max(1, (n_samples + window_size - 1) // window_size) + if n_win > max_windows: + n_win = max_windows + chunk_counts.append(n_win) + + outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + if "input_features_mask" in outputs: + outputs["feature_attention_mask"] = outputs.pop("input_features_mask") + + outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long) + + return outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _audioflamingo3_field_config(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "") + audio_token_id = vocab.get(audio_token) + if audio_token_id is None: + # Fallback if not found, though it should be there + audio_token_id = processor.audio_token_id + + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") + chunk_counts = out_mm_data.get("chunk_counts") + + def get_replacement_audioflamingo3(item_idx: int): + if feature_attention_mask is not None: + if chunk_counts is not None: + counts = ( + chunk_counts.tolist() + if isinstance(chunk_counts, torch.Tensor) + else chunk_counts + ) + start_idx = sum(counts[:item_idx]) + count = counts[item_idx] + end_idx = start_idx + count + + if isinstance(feature_attention_mask, list): + mask_list = feature_attention_mask[start_idx:end_idx] + if len(mask_list) > 0 and 
isinstance( + mask_list[0], torch.Tensor + ): + mask = torch.stack(mask_list) + else: + mask = torch.tensor(mask_list) + else: + mask = feature_attention_mask[start_idx:end_idx] + else: + # feature_attention_mask is list[Tensor] or Tensor + if isinstance(feature_attention_mask, list): + mask = feature_attention_mask[item_idx] + else: + mask = feature_attention_mask[item_idx].unsqueeze(0) + + # mask shape: (num_chunks, 3000) + input_lengths = mask.sum(-1) + conv_lengths = (input_lengths - 1) // 2 + 1 + audio_output_lengths = (conv_lengths - 2) // 2 + 1 + num_features = audio_output_lengths.sum().item() + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + num_features = audio_embeds.shape[0] + + if num_features == 0: + raise ValueError("Audio is too short") + + audio_tokens = [audio_token_id] * int(num_features) + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_audioflamingo3, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + AudioFlamingo3MultiModalProcessor, + info=AudioFlamingo3ProcessingInfo, + dummy_inputs=AudioFlamingo3DummyInputsBuilder, +) +class AudioFlamingo3ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA +): + """ + AudioFlamingo3 model for conditional generation. + + This model integrates a Whisper-based audio encoder with a Qwen2 language model. + It supports multi-chunk audio processing. + """ + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model.", + connector="multi_modal_projector.", + tower_model="audio_tower.", + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + self.audio_tower = AudioFlamingo3Encoder( + config.audio_config, + ) + self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config) + + self.quant_config = quant_config + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> AudioFlamingo3Inputs | None: + input_features = kwargs.pop("input_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + feature_attention_mask = kwargs.pop("feature_attention_mask", None) + chunk_counts = kwargs.pop("chunk_counts", None) + + if input_features is None and audio_embeds is None: + return None + + if audio_embeds is not None: + return AudioFlamingo3EmbeddingInputs( + type="audio_embeds", audio_embeds=audio_embeds + ) + + if input_features is not None: + return AudioFlamingo3FeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=feature_attention_mask, + chunk_counts=chunk_counts, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, 
audio_input: AudioFlamingo3Inputs + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + if audio_input["type"] == "audio_embeds": + audio_embeds = audio_input["audio_embeds"] + return tuple(audio_embeds) + + input_features = audio_input["input_features"] + feature_attention_mask = audio_input["feature_attention_mask"] + chunk_counts = audio_input.get("chunk_counts") + + if isinstance(input_features, list): + input_features = torch.cat(input_features, dim=0) + feature_attention_mask = torch.cat(feature_attention_mask, dim=0) + + if chunk_counts is None: + chunk_counts = [1] * input_features.shape[0] + elif isinstance(chunk_counts, torch.Tensor): + chunk_counts = chunk_counts.tolist() + elif ( + isinstance(chunk_counts, list) + and chunk_counts + and isinstance(chunk_counts[0], torch.Tensor) + ): + chunk_counts = [c.item() for c in chunk_counts] + + # Calculate output lengths + input_lengths = feature_attention_mask.sum(-1) + # Conv downsampling + conv_lengths = (input_lengths - 1) // 2 + 1 + # AvgPool downsampling + audio_output_lengths = (conv_lengths - 2) // 2 + 1 + + batch_size, _, max_mel_seq_len = input_features.shape + + # Calculate max_seq_len after convs (before pooling) for attention mask + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=conv_lengths.dtype, + device=conv_lengths.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len) + # Create mask + padding_mask = seq_range >= lengths_expand + + audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.audio_tower.conv1.weight.dtype, + device=self.audio_tower.conv1.weight.device, + ) + audio_attention_mask[audio_attention_mask_] = float("-inf") + + # Forward pass + audio_features = self.audio_tower( + input_features, attention_mask=audio_attention_mask + ) + + # Project + audio_features = self.multi_modal_projector(audio_features) + + # Masking after pooling + num_audios, max_audio_tokens, embed_dim = audio_features.shape + audio_output_lengths = audio_output_lengths.unsqueeze(1) + audio_features_mask = ( + torch.arange(max_audio_tokens) + .expand(num_audios, max_audio_tokens) + .to(audio_output_lengths.device) + < audio_output_lengths + ) + masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) + + # Split to tuple of embeddings for individual audio input. 
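# Worked example (illustrative) of the length bookkeeping above for one full
# 30 s chunk, which the Whisper feature extractor turns into 3000 mel frames:
mel_frames = 3000
conv_len = (mel_frames - 1) // 2 + 1    # conv downsampling -> 1500
audio_tokens = (conv_len - 2) // 2 + 1  # AvgPool1d(2, 2)   -> 750 tokens per chunk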
+ chunk_embeddings = torch.split( + masked_audio_features, audio_output_lengths.flatten().tolist() + ) + + grouped_embeddings = [] + current_idx = 0 + for count in chunk_counts: + audio_chunks = chunk_embeddings[current_idx : current_idx + count] + grouped_embeddings.append(torch.cat(audio_chunks, dim=0)) + current_idx += count + return tuple(grouped_embeddings) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return [] + masked_audio_features = self._process_audio_input(audio_input) + return masked_audio_features + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bagel.py b/vllm/model_executor/models/bagel.py new file mode 100644 index 0000000000000..98229c6d4ca1b --- /dev/null +++ b/vllm/model_executor/models/bagel.py @@ -0,0 +1,584 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +"""Inference-only BAGEL model compatible with HuggingFace weights. + +BAGEL is a unified multimodal model for image understanding and generation. +For vLLM, we focus on the image understanding (vision-to-text) capabilities. 
+""" + +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, TypeAlias + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processors.bagel import BagelProcessor +from vllm.utils.tensor_schema import TensorSchema + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .siglip import SiglipVisionModel +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class BagelImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ + + type: Literal["pixel_values"] + pixel_values: torch.Tensor # Shape: (bn, 3, h, w) + + +BagelImageInputs: TypeAlias = BagelImagePixelInputs + + +class BagelVisionMLP(nn.Module): + """MLP connector for vision features.""" + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + act_layer: str = "gelu_pytorch_tanh", + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.act = get_act_fn(act_layer) + self.fc2 = RowParallelLinear( + hidden_features, + out_features, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.act(x) + x, _ = self.fc2(x) + return x + + +class PositionEmbedding(nn.Module): + """2D position embedding for vision tokens using sin-cos embeddings.""" + + def __init__(self, max_num_patch_per_side: int, hidden_size: int): + super().__init__() + self.max_num_patch_per_side = max_num_patch_per_side + self.hidden_size = hidden_size + + # Create learnable 2D position embeddings (frozen sin-cos) + pos_embed = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side) + self.register_buffer( + "pos_embed", + torch.from_numpy(pos_embed).float(), + persistent=False, + ) + + @staticmethod + def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int): + """Generate 2D sin-cos position embeddings.""" + import numpy as np + + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # w goes first + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = PositionEmbedding._get_2d_sincos_pos_embed_from_grid( + embed_dim, grid + ) + return pos_embed + + @staticmethod + def 
_get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid): + """Generate 2D sin-cos position embeddings from grid.""" + import numpy as np + + assert embed_dim % 2 == 0 + # use half of dimensions to encode grid_h + emb_h = PositionEmbedding._get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0] + ) + emb_w = PositionEmbedding._get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1] + ) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + @staticmethod + def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos): + """Generate 1D sin-cos position embeddings.""" + import numpy as np + + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + """ + Args: + position_ids: Flattened position IDs, shape (N,) where each ID + corresponds to a position in the flattened grid + Returns: + Position embeddings of shape (N, hidden_size) + """ + # Ensure position_ids are on the same device as pos_embed + position_ids = position_ids.to(self.pos_embed.device) + return self.pos_embed[position_ids] + + +class BagelProcessingInfo(BaseProcessingInfo): + """Processing information for BAGEL model.""" + + def get_hf_processor(self, **kwargs: object) -> BagelProcessor: + from vllm.transformers_utils.processor import cached_get_image_processor + + image_processor = cached_get_image_processor( + self.ctx.model_config.model, + revision=self.ctx.model_config.revision, + trust_remote_code=self.ctx.model_config.trust_remote_code, + ) + + tokenizer = self.get_tokenizer() + + return BagelProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + **kwargs, + ) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + hf_config = self.get_hf_config() + # Calculate max tokens per image + # For BAGEL: (vit_max_num_patch_per_side) ** 2 + max_num_patches = hf_config.vit_max_num_patch_per_side**2 + return {"image": max_num_patches} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vit_config = hf_config.vit_config + patch_size = vit_config.patch_size + + # Calculate number of patches + num_patches_h = image_height // patch_size + num_patches_w = image_width // patch_size + return num_patches_h * num_patches_w + + +class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]): + """Build dummy inputs for BAGEL model profiling.""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + # Use a simple placeholder for each image + return "<|image_pad|>" * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + hf_config = self.info.get_hf_config() + vit_config = hf_config.vit_config + + # Use the configured image size + image_size = vit_config.image_size + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=image_size, + height=image_size, + 
num_images=num_images, + overrides=image_overrides, + ), + } + + +class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]): + """Multimodal processor for BAGEL model.""" + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptReplacement]: + """Replace image placeholders with the correct number of tokens.""" + hf_config = self.info.get_hf_config() + + # Get the tokenizer to look up the image token ID + tokenizer = self.info.get_tokenizer() + image_token_id = tokenizer.get_vocab().get("<|image_pad|>") + if image_token_id is None: + raise ValueError( + "Image token '<|image_pad|>' not found in tokenizer vocabulary" + ) + + def get_replacement_bagel(item_idx: int): + # For BAGEL, calculate number of tokens based on max patch size + num_tokens = hf_config.vit_max_num_patch_per_side**2 + # Use the image token ID from tokenizer + return [image_token_id] * num_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_bagel, + ) + ] + + def _get_mm_fields_config( + self, + hf_inputs: Any, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return { + "pixel_values": MultiModalFieldConfig.batched("image"), + } + + +@MULTIMODAL_REGISTRY.register_processor( + BagelMultiModalProcessor, + info=BagelProcessingInfo, + dummy_inputs=BagelDummyInputsBuilder, +) +class BagelForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP +): + """ + BAGEL: A unified multimodal model for image understanding and generation. + + For vLLM, we focus on the image understanding (vision-to-text) capabilities. + The image generation part is not supported in vLLM. + """ + + # Weight mapping from HF to vLLM + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.": "language_model.", + "vit_model.": "vit_model.", + "connector.": "connector.", + "vit_pos_embed.": "vit_pos_embed.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + # Ensure we have a BagelConfig (check by name to handle trust_remote_code) + # When trust_remote_code=True, the config comes from transformers_modules + if type(config).__name__ != "BagelConfig": + raise ValueError( + f"Expected BagelConfig, got {type(config).__name__}. " + "Make sure the model config is properly loaded." 
+ ) + + self.config = config + self.multimodal_config = multimodal_config + + # Initialize language model (Qwen2) + # Pass the llm_config from BagelConfig to initialize Qwen2 properly + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.llm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + # Initialize vision model (SigLIP) if visual understanding is enabled + if config.visual_und: + # Fix vit_config: checkpoint has 26 layers (0-25) but config says 27 + # Also disable head as it's not in checkpoint + vit_config = config.vit_config + if vit_config.num_hidden_layers == 27: + logger.warning( + "Overriding vit_config.num_hidden_layers from 27 to 26 " + "to match the Bagel model checkpoint." + ) + vit_config.num_hidden_layers = 26 + if not hasattr(vit_config, "vision_use_head"): + logger.warning( + "Setting vit_config.vision_use_head to False as it is not " + "present in the Bagel model checkpoint." + ) + vit_config.vision_use_head = False + + self.vit_model = SiglipVisionModel( + config=vit_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vit_model"), + ) + + # Initialize connector (MLP) + vit_hidden_size = config.vit_config.hidden_size + llm_hidden_size = config.llm_config.hidden_size + + self.connector = BagelVisionMLP( + in_features=vit_hidden_size, + hidden_features=llm_hidden_size, + out_features=llm_hidden_size, + act_layer=config.connector_act, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "connector"), + ) + + # Position embedding for vision tokens + self.vit_pos_embed = PositionEmbedding( + max_num_patch_per_side=config.vit_max_num_patch_per_side, + hidden_size=llm_hidden_size, + ) + else: + self.vit_model = None + self.connector = None + self.vit_pos_embed = None + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> BagelImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is None: + return None + + return BagelImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + ) + + def _process_image_input( + self, image_input: BagelImageInputs + ) -> tuple[torch.Tensor, ...]: + """Process image inputs through vision encoder and connector.""" + pixel_values = image_input["pixel_values"] + + # Handle potential extra batch dimension + # Expected shape: (batch_size * num_images, 3, H, W) + # But might receive: (batch_size, num_images, 3, H, W) + if pixel_values.ndim == 5: + # Flatten batch and num_images dimensions + batch_size, num_images, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape( + batch_size * num_images, channels, height, width + ) + + # Get vision features from SigLIP + # pixel_values shape: (batch_size * num_images, 3, H, W) + vision_features = self.vit_model(pixel_values) + + # Pass through connector + vision_embeds = self.connector(vision_features) + + # Add position embeddings + batch_size, num_patches, hidden_size = vision_embeds.shape + patch_size = self.config.vit_config.patch_size + image_size = self.config.vit_config.image_size + + # Calculate grid dimensions + num_patches_per_side = image_size // patch_size + + # Create flattened position IDs (0 to num_patches-1) + # For BAGEL, we use extrapolate mode by default + h_coords = torch.arange(num_patches_per_side, device=vision_embeds.device) + w_coords = torch.arange(num_patches_per_side, 
device=vision_embeds.device) + position_ids = ( + h_coords[:, None] * self.config.vit_max_num_patch_per_side + w_coords + ).flatten() + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1).flatten() + + # Add position embeddings + pos_embeds = self.vit_pos_embed(position_ids) + pos_embeds = pos_embeds.reshape(batch_size, num_patches, hidden_size) + # Ensure pos_embeds are on the same device as vision_embeds + pos_embeds = pos_embeds.to(vision_embeds.device) + vision_embeds = vision_embeds + pos_embeds + + # Split by image + return tuple(vision_embeds) + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + """Get multimodal embeddings from input.""" + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input) + + def get_language_model(self) -> nn.Module: + return self.language_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + """Run forward pass for BAGEL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a batch. + positions: Flattened (concatenated) position ids corresponding to a batch. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. + """ + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights from checkpoint.""" + skip_prefixes = [] + # Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module + skip_prefixes.append("vit_pos_embed.pos_embed") + + # If visual understanding is disabled, skip vision-related weights + if self.vit_model is None: + skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"]) + + # Skip generation-related weights since we only support text2text and image2text + # Filter out all image generation components: + # - 'moe_gen': MoE generation weights + # - 'latent_pos_embed': Latent position embeddings for VAE + # - 'llm2vae', 'vae2llm': LLM-VAE projections + # - 'time_embedder': Timestep embeddings for diffusion + # - VAE encoder/decoder: Use specific prefixes to avoid matching vision encoder + generation_keywords = [ + "moe_gen", + "latent_pos_embed", + "llm2vae", + "vae2llm", + "time_embedder", + ] + vae_prefixes = [ + "decoder.", + "encoder.", + ] # VAE encoder/decoder, not vision encoder + filtered_weights = [] + for name, tensor in weights: + # Skip generation-related keywords + if any(skip in name for skip in generation_keywords): + continue + if any(name.startswith(prefix) for prefix in vae_prefixes): + continue + + if "patch_embedding.weight" in name and tensor.ndim == 2: + out_channels = tensor.shape[0] + in_features = tensor.shape[1] + patch_size = self.config.vit_config.patch_size + in_channels = self.config.vit_config.num_channels + if in_features == in_channels * patch_size * patch_size: + tensor = tensor.reshape( + out_channels, patch_size, patch_size, 
in_channels + ) + tensor = tensor.permute(0, 3, 1, 2).contiguous() + + filtered_weights.append((name, tensor)) + + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index e774cd647ea8c..ee429bf458843 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -55,7 +55,9 @@ class BertEmbedding(nn.Module): "position_ids", torch.arange(config.max_position_embeddings).unsqueeze(0), ) - self.position_embedding_type = config.position_embedding_type + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) if self.position_embedding_type != "absolute": raise ValueError( "Only 'absolute' position_embedding_type" + " is supported" diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 06cc92ee88180..4b08472538db4 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -363,7 +363,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): else: kernel_block_alignment_size = 16 if ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and model_config.get_head_size() == 256 and ( attention_config.backend is None diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index da19d8fdb15e0..6d8dbec9236c9 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import utils as dist_utils from vllm.distributed.parallel_state import ( @@ -30,6 +29,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, @@ -159,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): return processor -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - - cos = freqs.cos() - sin = freqs.sin() - - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - - output = (tensor * cos) + (rotate_half(tensor) * sin) - - output = output.to(orig_dtype) - - return output - - class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() @@ -254,11 +230,15 @@ class 
DotsVisionAttention(nn.Module): bias: bool = True, *, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.embed_dim = dim self.tp_size = ( @@ -287,31 +267,18 @@ class DotsVisionAttention(nn.Module): prefix=f"{prefix}.proj", disable_tp=use_data_parallel, ) - # Select attention backend - self.attn_backend = get_vit_attn_backend( - self.hidden_size_per_attention_head, - torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + multimodal_config=multimodal_config, + prefix=f"{prefix}.attn", ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, ) - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Unsupported vision attention backend: {self.attn_backend}" - ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } def forward( self, @@ -319,7 +286,7 @@ class DotsVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor | None = None, *, - max_seqlen: int | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: # [S, C] -> [S, B=1, C] x = hidden_states.unsqueeze(1) @@ -333,44 +300,20 @@ class DotsVisionAttention(nn.Module): if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) - k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) - v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) - output = self.flash_attn_varlen_func( - q_, - k_, - v_, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - context_layer = output.view( - bs, - -1, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - outputs = [] - for i in range(1, len(cu_seqlens)): - s = int(cu_seqlens[i - 1]) - e = int(cu_seqlens[i]) - q_i = q[:, s:e].permute(0, 2, 1, 3) - k_i = k[:, s:e].permute(0, 2, 1, 3) - v_i = v[:, s:e].permute(0, 2, 1, 3) - out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - out_i = out_i.permute(0, 2, 1, 3) - outputs.append(out_i) - context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - else: - raise RuntimeError("Unsupported attention backend") + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) # [B,S,H,D] -> [S,B,H*D] -> [S, C] context_layer = context_layer.permute(1, 0, 2, 3).contiguous() @@ -385,14 +328,19 @@ class 
DotsSwiGLUFFN(nn.Module): config, *, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() hidden_features = config.intermediate_size in_features = config.embed_dim bias = config.use_bias + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) # Referenced aimv2.py AIMv2SwiGLUFFN self.fc13 = MergedColumnParallelLinear( in_features, @@ -498,9 +446,8 @@ class DotsVisionBlock(nn.Module): config, *, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() @@ -510,16 +457,15 @@ class DotsVisionBlock(nn.Module): num_heads=config.num_attention_heads, bias=config.use_bias, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.mlp = DotsSwiGLUFFN( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) @@ -546,12 +492,11 @@ class DotsVisionTransformer(nn.Module): self, config: DotsVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.config = config @@ -561,6 +506,11 @@ class DotsVisionTransformer(nn.Module): head_dim = config.embed_dim // config.num_attention_heads self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -578,9 +528,8 @@ class DotsVisionTransformer(nn.Module): DotsVisionBlock( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{i}", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) for i in range(num_layers) ] @@ -592,6 +541,11 @@ class DotsVisionTransformer(nn.Module): else: self.post_trunk_norm = None + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.merger = PatchMerger( dim=config.hidden_size, context_dim=config.embed_dim, @@ -647,7 +601,7 @@ class DotsVisionTransformer(nn.Module): self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( @@ -733,17 +687,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA self.config.vision_config = vision_config else: vision_config = self.config.vision_config - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) + self.vision_tower = DotsVisionTransformer( vision_config, 
quant_config=self.quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "vision_tower"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( vllm_config=vllm_config, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 053d260cc09b2..61cf78fdb5a67 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -33,14 +33,14 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from transformers import BatchFeature from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -69,7 +72,6 @@ from vllm.multimodal.processing import ( PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -89,52 +91,6 @@ logger = init_logger(__name__) # === Vision Transformer === # -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - apply_rotary_emb = apply_rotary_emb_torch - if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - output = apply_rotary_emb(t_, cos, sin).type_as(t) - return output - - def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): """All-gather the input tensor interleavely across model parallel group.""" import torch.distributed as dist @@ -163,8 +119,8 @@ class Ernie4_5_VisionAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -193,33 +149,18 @@ class Ernie4_5_VisionAttention(nn.Module): prefix=f"{prefix}.proj", ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, + prefix=f"{prefix}.attn", ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, ) - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Ernie45-VL does not support {self.attn_backend} backend now." - ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } - def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape @@ -253,58 +194,32 @@ class Ernie4_5_VisionAttention(nn.Module): x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] q, k, v = self.split_qkv(x) - batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - outputs = [] - - lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - q_chunks = torch.split(q, lens, dim=1) - k_chunks = torch.split(k, lens, dim=1) - v_chunks = torch.split(v, lens, dim=1) - for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + output = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -350,8 +265,8 @@ class Ernie4_5_VisionBlock(nn.Module): act_layer: type[nn.Module] = QuickGELU, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -366,8 +281,8 @@ class Ernie4_5_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - attn_backend_override=attn_backend_override, ) self.mlp = Ernie4_5_VisionMLP( @@ -383,7 +298,7 @@ class Ernie4_5_VisionBlock(nn.Module): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), @@ -441,8 +356,8 @@ class Ernie4_5_VisionTransformer(nn.Module): vision_config, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -477,8 +392,8 @@ class Ernie4_5_VisionTransformer(nn.Module): mlp_ratio=mlp_ratio, norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] @@ -489,6 +404,9 @@ class Ernie4_5_VisionTransformer(nn.Module): ) self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -535,13 +453,13 @@ class Ernie4_5_VisionTransformer(nn.Module): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None: + def compute_attn_mask_seqlen(self, cu_seqlens: 
torch.Tensor) -> torch.Tensor | None: max_seqlen = None if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( @@ -1304,17 +1222,12 @@ class Ernie4_5_VLMoeForConditionalGeneration( self.config = config self.multimodal_config = multimodal_config - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.vision_model = Ernie4_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "vision_model"), - attn_backend_override=attn_backend_override, ) self.language_model = Ernie4_5_VLMoeForCausalLM( diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 786482d77a1d2..84989537da6e2 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend -from vllm.config import VllmConfig +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, +) +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import utils as dist_utils @@ -63,6 +65,9 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -93,7 +98,7 @@ from .interfaces import ( SupportsMultiModal, SupportsPP, ) -from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision +from .qwen2_vl import _create_qwen2vl_field_factory from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -191,10 +196,15 @@ class Glm4vVisionMLP(nn.Module): hidden_features: int, bias: bool = False, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, @@ -248,12 +258,16 @@ class Glm4vVisionAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. 
+ use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.tp_size = ( 1 if use_data_parallel else get_tensor_model_parallel_world_size() ) @@ -287,33 +301,13 @@ class Glm4vVisionAttention(nn.Module): disable_tp=use_data_parallel, ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"GLM-4V does not support {self.attn_backend} backend now." - ) - - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] @@ -338,61 +332,33 @@ class Glm4vVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] q, k, v = self.split_qkv(x) - batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision( - qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. 
- outputs = [] - - lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - q_chunks = torch.split(q, lens, dim=1) - k_chunks = torch.split(k, lens, dim=1) - v_chunks = torch.split(v, lens, dim=1) - for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -406,9 +372,8 @@ class Glm4vVisionBlock(nn.Module): mlp_hidden_dim: int, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -420,17 +385,16 @@ class Glm4vVisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.mlp = Glm4vVisionMLP( dim, mlp_hidden_dim, bias=False, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -489,11 +453,16 @@ class Glm4vPatchMerger(nn.Module): d_model: int, context_dim: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, bias: bool = False, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = d_model self.proj = ColumnParallelLinear( self.hidden_size, @@ -649,19 +618,19 @@ class Glm4vVisionTransformer(nn.Module): vision_config: Glm4vVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() + assert multimodal_config is not None, "multimodal_config must be provided" + patch_size = vision_config.patch_size temporal_patch_size = vision_config.temporal_patch_size in_channels = vision_config.in_channels depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads - self.use_data_parallel = use_data_parallel self.patch_size = vision_config.patch_size self.spatial_merge_size = vision_config.spatial_merge_size @@ -690,9 +659,8 @@ class Glm4vVisionTransformer(nn.Module): mlp_hidden_dim=vision_config.out_hidden_size, norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] @@ -701,9 +669,9 @@ class Glm4vVisionTransformer(nn.Module): d_model=vision_config.out_hidden_size, 
context_dim=vision_config.intermediate_size, quant_config=quant_config, + multimodal_config=multimodal_config, bias=False, prefix=f"{prefix}.merger", - use_data_parallel=self.use_data_parallel, ) self.embeddings = Glm4vVisionEmbeddings(vision_config) @@ -723,7 +691,7 @@ class Glm4vVisionTransformer(nn.Module): self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + attn_backend_override=multimodal_config.mm_encoder_attn_backend, ) @property @@ -775,13 +743,13 @@ class Glm4vVisionTransformer(nn.Module): def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> int | None: + ) -> torch.Tensor | None: max_seqlen = None if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( @@ -1465,18 +1433,12 @@ class Glm4vForConditionalGeneration( self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Glm4vVisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-5), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) if config.model_type == "glm4v": diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index f31da0ee302b3..fcf88953ba20f 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from einops import rearrange from transformers import PretrainedConfig from transformers.activations import GELUActivation @@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int -from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -32,6 +30,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -61,7 +62,6 @@ from vllm.multimodal.processing import ( PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -80,7 +80,6 @@ from .utils import ( is_pp_missing_parameter, maybe_prefix, ) -from .vision import 
get_vit_attn_backend logger = init_logger(__name__) @@ -344,20 +343,14 @@ def apply_rotary_pos_emb_flashatt( cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - elif current_platform.is_rocm(): - from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb - else: - # For other platforms, use PyTorch fallback - from vllm.model_executor.layers.rotary_embedding.common import ( - apply_rotary_emb_torch, - ) + apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) - apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True) + q_embed = apply_rotary_emb(q, cos, sin) + k_embed = apply_rotary_emb(k, cos, sin) - q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed @@ -369,8 +362,8 @@ class KeyeSiglipAttention(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -408,34 +401,14 @@ class KeyeSiglipAttention(nn.Module): prefix=f"{prefix}.out_proj", ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_heads, head_size=self.head_dim, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + num_kv_heads=self.num_kv_heads, + prefix=f"{prefix}.attn", + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Keye-VL does not support {self.attn_backend} backend now." - ) - - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } - def forward( self, hidden_states: torch.Tensor, @@ -450,8 +423,7 @@ class KeyeSiglipAttention(nn.Module): dim=-1, ) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - batch_size = q.shape[0] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() if rope_emb is None: q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) @@ -482,38 +454,14 @@ class KeyeSiglipAttention(nn.Module): self.head_dim, ) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - causal=False, - softmax_scale=self.scale, - ) - context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i) - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - - context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(context_layer, "b s h d -> b s (h d)") output, _ = self.out_proj(context_layer) return output @@ -547,8 +495,8 @@ class KeyeSiglipEncoderLayer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -556,8 +504,8 @@ class KeyeSiglipEncoderLayer(nn.Module): self.self_attn = KeyeSiglipAttention( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", - attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -601,8 +549,8 @@ class KeyeSiglipEncoder(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -614,8 +562,8 @@ class KeyeSiglipEncoder(nn.Module): KeyeSiglipEncoderLayer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{layer_idx}", - attn_backend_override=attn_backend_override, ) for layer_idx in range(config.num_hidden_layers) ] @@ -696,8 +644,8 @@ class KeyeSiglipVisionTransformer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -707,8 +655,8 @@ class KeyeSiglipVisionTransformer(nn.Module): self.encoder = KeyeSiglipEncoder( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.encoder", - attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -779,16 +727,16 @@ class KeyeSiglipVisionModel(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.vision_model = KeyeSiglipVisionTransformer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vision_model", - attn_backend_override=attn_backend_override, ) self.quant_config = quant_config @@ -1329,16 +1277,11 @@ class BaseKeyeModule(nn.Module): self.config = config self.multimodal_config = multimodal_config - attn_backend_override = ( - 
multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = KeyeSiglipVisionModel( config.vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - attn_backend_override=attn_backend_override, ) self.mlp_AR = self._build_projector( diff --git a/vllm/model_executor/models/opencua.py b/vllm/model_executor/models/opencua.py index 23668cc2b746e..35a6a78f653ef 100644 --- a/vllm/model_executor/models/opencua.py +++ b/vllm/model_executor/models/opencua.py @@ -240,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ) if multimodal_config.get_limit_per_prompt("image"): - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = OpenCUAVisionTransformer( vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, + multimodal_config=self.multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) else: self.visual = None diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 0ad22aab748e3..945138b5972f7 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -10,8 +10,7 @@ import torch import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -104,18 +103,16 @@ class VisualTokenizer(torch.nn.Module): config: PretrainedConfig, visual_vocab_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config self.vit = self._init_backbone( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vit", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) # reserved tokens for INDICATOR_IDS head_dim = visual_vocab_size - len(INDICATOR_IDS) @@ -133,18 +130,16 @@ class VisualTokenizer(torch.nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": return Siglip2NavitModel( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=prefix, - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") @@ -468,17 +463,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): prefix=maybe_prefix(prefix, "llm"), ) - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual_tokenizer = VisualTokenizer( config=config.vit_config, 
visual_vocab_size=config.visual_vocab_size, + multimodal_config=multimodal_config, quant_config=quant_config, prefix=f"{prefix}.visual_tokenizer", - attn_backend_override=attn_backend_override, ) self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 9703a5b417d02..56565266c0dcc 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -22,8 +22,7 @@ from typing import Annotated, Literal import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from transformers import BatchFeature, PretrainedConfig from transformers.activations import GELUActivation from transformers.modeling_outputs import ( @@ -32,13 +31,10 @@ from transformers.modeling_outputs import ( from transformers.utils import torch_int from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.attention.ops.vit_attn_wrappers import ( - vit_flash_attn_wrapper, -) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -51,7 +47,7 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding.common import ( - dispatch_rotary_emb_function, + ApplyRotaryEmb, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, @@ -134,47 +130,6 @@ def smart_resize( return h_bar, w_bar -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - output = rotary_emb_function(t_, cos, sin).type_as(t) - return output - - class PaddleOCRVLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config() @@ -578,9 +533,8 @@ class SiglipAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -608,18 +562,16 @@ class SiglipAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - - self.attn_backend = attn_backend - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + multimodal_config=multimodal_config, + prefix=f"{prefix}.attn", + ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: seq_len, bs, _ = qkv.shape @@ -662,47 +614,23 @@ class SiglipAttention(nn.Module): if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - if max_seqlen is None: - raise ValueError("Flash attention backend requires max_seqlen.") - context_layer = vit_flash_attn_wrapper( - q, - k, - v, - cu_seqlens, - max_seqlen, - batch_size, - self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, - ) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - rearrange(tensor, "b s h d -> b h s d") - for tensor in (q_i, k_i, v_i) - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() - else: - raise RuntimeError( - f"PaddleOCR-VL does not support {self.attn_backend} backend now." 
- ) + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(context_layer, "b s h d -> b s (h d)") output, _ = self.out_proj(context_layer) - output = rearrange(output, "s b d -> b s d") return output @@ -774,10 +702,8 @@ class SiglipEncoderLayer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - *, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -787,9 +713,8 @@ class SiglipEncoderLayer(nn.Module): num_heads=config.num_attention_heads, projection_size=config.hidden_size, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", - attn_backend=attn_backend, - attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -832,14 +757,18 @@ class SiglipEncoder(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads + + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -858,9 +787,8 @@ class SiglipEncoder(nn.Module): SiglipEncoderLayer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{layer_idx}", - attn_backend=self.attn_backend, - attn_backend_override=attn_backend_override, ) for layer_idx in range(config.num_hidden_layers) ] @@ -941,8 +869,8 @@ class SiglipVisionTransformer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -952,8 +880,8 @@ class SiglipVisionTransformer(nn.Module): self.encoder = SiglipEncoder( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.encoder", - attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -991,16 +919,16 @@ class SiglipVisionModel(nn.Module): self, config, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.vision_model = SiglipVisionTransformer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vision_model", - attn_backend_override=attn_backend_override, ) self.quant_config = quant_config @@ -1119,17 +1047,11 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support self.config = config self.multimodal_config = multimodal_config - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) - self.visual = SiglipVisionModel( config=config.vision_config, quant_config=quant_config, 
+ multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - attn_backend_override=attn_backend_override, ) self.mlp_AR = Projector(config, config.vision_config) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index faf2d80d24bba..555e6ea4b8cb2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -59,7 +59,8 @@ from vllm.multimodal.processing import ( from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 492ba2fb12145..61a6e67805d6a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -281,6 +281,9 @@ class QWenBaseModel(nn.Module): self.transformer.make_empty_intermediate_tensors ) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.wte(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3af4a49cd77cc..f4c2d3cb75d25 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -122,6 +122,8 @@ class Qwen2Attention(nn.Module): prefix: str = "", attn_type: str = AttentionType.DECODER, dual_chunk_attention_config: dict[str, Any] | None = None, + qk_norm: bool = False, + rms_norm_eps: float = 1e-6, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -144,6 +146,7 @@ class Qwen2Attention(nn.Module): self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.dual_chunk_attention_config = dual_chunk_attention_config + self.qk_norm = qk_norm self.qkv_proj = QKVParallelLinear( hidden_size, @@ -162,6 +165,11 @@ class Qwen2Attention(nn.Module): prefix=f"{prefix}.o_proj", ) + # QK Normalization support (used in BAGEL and some other models) + if self.qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.rotary_emb = get_rope( self.head_dim, max_position=max_position, @@ -197,6 +205,23 @@ class Qwen2Attention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Apply QK normalization if enabled (before RoPE) + if self.qk_norm: + # Reshape to apply per-head normalization + # q shape: (total_tokens, q_size) -> (total_tokens, num_heads, head_dim) + total_tokens = q.shape[0] + q = q.view(total_tokens, self.num_heads, self.head_dim) + k = k.view(total_tokens, self.num_kv_heads, self.head_dim) + + # Apply normalization + q = self.q_norm(q) + k = self.k_norm(k) + + # Reshape back + q = q.view(total_tokens, self.q_size) + k = k.view(total_tokens, self.kv_size) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -227,6 +252,9 @@ class Qwen2DecoderLayer(nn.Module): else: attn_type = AttentionType.ENCODER_ONLY + # Check if QK normalization is enabled (used in BAGEL and some other models) + qk_norm 
= getattr(config, "qk_norm", False) + self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -238,6 +266,8 @@ class Qwen2DecoderLayer(nn.Module): prefix=f"{prefix}.self_attn", attn_type=attn_type, dual_chunk_attention_config=dual_chunk_attention_config, + qk_norm=qk_norm, + rms_norm_eps=config.rms_norm_eps, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -480,6 +510,8 @@ class Qwen2Model(nn.Module): continue if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 3438406c4fac1..f9bce4bf981b2 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -845,6 +845,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), + multimodal_config=multimodal_config, ) else: self.visual = None diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index fba06e34f6227..b730ac0315893 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( ) from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend -from vllm.attention.ops.vit_attn_wrappers import ( - vit_flash_attn_wrapper, - vit_torch_sdpa_wrapper, -) +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.forward_context import set_forward_context @@ -64,6 +60,9 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.vision import should_torch_compile_mm_vit @@ -99,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import ( Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, - apply_rotary_pos_emb_vision, ) from .utils import ( AutoWeightsLoader, @@ -267,10 +265,15 @@ class Qwen2_5_VisionMLP(nn.Module): bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] @@ -304,13 +307,16 @@ class Qwen2_5_VisionAttention(nn.Module): 
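# Standalone sketch of the per-head QK-normalization pattern added to
# Qwen2Attention above (reshape to heads, RMS-normalize over head_dim, flatten
# back before RoPE). A plain-PyTorch RMSNorm stand-in is used instead of vLLM's
# RMSNorm layer; all sizes below are illustrative only.
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last (head_dim) axis only.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


num_heads, num_kv_heads, head_dim = 8, 2, 64
total_tokens = 5
q = torch.randn(total_tokens, num_heads * head_dim)
k = torch.randn(total_tokens, num_kv_heads * head_dim)
q_weight = torch.ones(head_dim)
k_weight = torch.ones(head_dim)

# (tokens, heads * head_dim) -> (tokens, heads, head_dim) so the norm sees one
# head at a time, then back to the flat layout expected by rotary_emb / attn.
q = rms_norm(q.view(total_tokens, num_heads, head_dim), q_weight).view(total_tokens, -1)
k = rms_norm(k.view(total_tokens, num_kv_heads, head_dim), k_weight).view(total_tokens, -1)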
num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.tp_size = ( 1 if use_data_parallel @@ -342,18 +348,14 @@ class Qwen2_5_VisionAttention(nn.Module): prefix=f"{prefix}.proj", disable_tp=use_data_parallel, ) - self.attn_backend = attn_backend - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + multimodal_config=multimodal_config, ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) def forward( self, @@ -380,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module): qk_reshaped = einops.rearrange( qk, "b s two head head_dim -> (two b) s head head_dim", two=2 ) - qk_rotated = apply_rotary_pos_emb_vision( - qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_reshaped, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) qk_rotated = qk_rotated.view( 2, @@ -394,32 +398,17 @@ class Qwen2_5_VisionAttention(nn.Module): else: q, k, v = qkv.unbind(dim=2) - if self.is_flash_attn_backend: - context_layer = vit_flash_attn_wrapper( - q, - k, - v, - cu_seqlens, - max_seqlen, - batch_size, - self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, - ) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. 
- from vllm.platforms import current_platform + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) - # Never remove the next contiguous logic - # Without it, hallucinations occur with the backend - if current_platform.is_rocm(): - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - context_layer = vit_torch_sdpa_wrapper( - q, - k, - v, - cu_seqlens, - ) + context_layer = einops.rearrange( + context_layer, "b s h d -> s b (h d)", b=batch_size + ).contiguous() output, _ = self.proj(context_layer) return output @@ -443,10 +432,8 @@ class Qwen2_5_VisionBlock(nn.Module): act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -458,10 +445,8 @@ class Qwen2_5_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend=attn_backend, - attn_backend_override=attn_backend_override, ) self.mlp = Qwen2_5_VisionMLP( dim, @@ -469,8 +454,8 @@ class Qwen2_5_VisionBlock(nn.Module): act_fn=act_fn, bias=True, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -542,10 +527,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module): norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) @@ -586,9 +576,8 @@ class Qwen2_5_VisionTransformer(nn.Module): vision_config: Qwen2_5_VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -598,7 +587,6 @@ class Qwen2_5_VisionTransformer(nn.Module): depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads - self.use_data_parallel = use_data_parallel self.out_hidden_size = vision_config.out_hidden_size # args for get_window_index_thw @@ -612,7 +600,7 @@ class Qwen2_5_VisionTransformer(nn.Module): # DO NOT MOVE THIS IMPORT from vllm.compilation.backends import set_model_tag - with set_model_tag("Qwen2_5_VisionPatchEmbed"): + with set_model_tag("Qwen2_5_VisionPatchEmbed", is_encoder=True): self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, @@ -629,19 +617,17 @@ class Qwen2_5_VisionTransformer(nn.Module): rope_parameters={"partial_rotary_factor": 0.5}, ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.attn_backend = 
get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, @@ -651,7 +637,7 @@ class Qwen2_5_VisionTransformer(nn.Module): f"Qwen2.5-VL does not support {self.attn_backend} backend now." ) - with set_model_tag("Qwen2_5_VisionBlock"): + with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True): self.blocks = nn.ModuleList( [ Qwen2_5_VisionBlock( @@ -661,24 +647,22 @@ class Qwen2_5_VisionTransformer(nn.Module): act_fn=get_act_and_mul_fn(vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, - attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] ) - with set_model_tag("Qwen2_5_VisionPatchMerger"): + with set_model_tag("Qwen2_5_VisionPatchMerger", is_encoder=True): self.merger = Qwen2_5_VisionPatchMerger( d_model=vision_config.out_hidden_size, context_dim=self.hidden_size, norm_layer=norm_layer, spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, ) @property @@ -1200,18 +1184,12 @@ class Qwen2_5_VLForConditionalGeneration( if multimodal_config.get_limit_per_prompt( "image" ) or multimodal_config.get_limit_per_prompt("video"): - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen2_5_VisionTransformer( vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) else: self.visual = None diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 4e54208a59b67..321fbd764c0f5 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -33,7 +33,6 @@ from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from einops import rearrange from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor @@ -45,12 +44,10 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, -) -from vllm.config import VllmConfig +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import parallel_state +from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU @@ -62,8 +59,7 @@ from 
vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding.common import ( - apply_rotary_emb_torch, - dispatch_rotary_emb_function, + ApplyRotaryEmb, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -251,10 +247,15 @@ class Qwen2VisionMLP(nn.Module): hidden_features: int, act_layer: type[nn.Module] = QuickGELU, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.fc1 = ColumnParallelLinear( in_features, hidden_features, @@ -278,16 +279,6 @@ class Qwen2VisionMLP(nn.Module): return x -def apply_rotary_pos_emb_vision( - t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function( - default=partial(apply_rotary_emb_torch, is_neox_style=True) - ) - output = rotary_emb_function(t, cos, sin).type_as(t) - return output - - class Qwen2VisionAttention(nn.Module): def __init__( self, @@ -295,12 +286,16 @@ class Qwen2VisionAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.tp_size = ( 1 if use_data_parallel @@ -329,41 +324,32 @@ class Qwen2VisionAttention(nn.Module): disable_tp=use_data_parallel, ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Qwen2-VL does not support {self.attn_backend} backend now." 
- ) - - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] new_shape = ( seq_len, @@ -387,60 +373,27 @@ class Qwen2VisionAttention(nn.Module): # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] q, k, v = self.split_qkv(x) - batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision( - qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. 
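# The MMEncoderAttention call introduced above is assumed to compute non-causal
# attention independently within each cu_seqlens segment, which is what the
# removed per-entry TORCH_SDPA branch spelled out explicitly. Standalone sketch
# of that varlen semantics (toy shapes, plain SDPA):
import torch
import torch.nn.functional as F


def varlen_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                cu_seqlens: torch.Tensor) -> torch.Tensor:
    # q, k, v: (total_tokens, num_heads, head_dim); cu_seqlens: (num_seqs + 1,)
    outputs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # (s_i, h, d) -> (h, s_i, d) so SDPA treats heads as the batch dim.
        q_i, k_i, v_i = (x[start:end].transpose(0, 1) for x in (q, k, v))
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
        outputs.append(out_i.transpose(0, 1))
    return torch.cat(outputs, dim=0)  # (total_tokens, num_heads, head_dim)


q = k = v = torch.randn(7, 4, 32)
cu_seqlens = torch.tensor([0, 3, 7])  # e.g. two images of 3 and 4 patches
out = varlen_sdpa(q, k, v, cu_seqlens)
assert out.shape == (7, 4, 32)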
- from vllm.platforms import current_platform - - if current_platform.is_rocm(): - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - outputs = [] - - lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - q_chunks = torch.split(q, lens, dim=1) - k_chunks = torch.split(k, lens, dim=1) - v_chunks = torch.split(v, lens, dim=1) - for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -455,9 +408,8 @@ class Qwen2VisionBlock(nn.Module): act_layer: type[nn.Module] = QuickGELU, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -471,17 +423,16 @@ class Qwen2VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.mlp = Qwen2VisionMLP( dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -541,10 +492,15 @@ class Qwen2VisionPatchMerger(nn.Module): norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) @@ -588,9 +544,8 @@ class Qwen2VisionTransformer(nn.Module): vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -604,7 +559,11 @@ class Qwen2VisionTransformer(nn.Module): num_heads = vision_config.num_heads mlp_ratio = vision_config.mlp_ratio - self.use_data_parallel = use_data_parallel + self.use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.out_hidden_size = vision_config.hidden_size self.spatial_merge_size = spatial_merge_size @@ -636,8 +595,7 @@ class Qwen2VisionTransformer(nn.Module): norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) for layer_idx in range(depth) ] @@ -648,7 +606,10 @@ class Qwen2VisionTransformer(nn.Module): norm_layer=norm_layer, quant_config=quant_config, 
prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, + multimodal_config=multimodal_config, + ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, @@ -709,7 +670,7 @@ class Qwen2VisionTransformer(nn.Module): AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, }: - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( @@ -1313,18 +1274,12 @@ class Qwen2VLForConditionalGeneration( if multimodal_config.get_limit_per_prompt( "image" ) or multimodal_config.get_limit_per_prompt("video"): - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) else: self.visual = None diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 6a5447ad0fed4..ccf6cc6e5894b 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1092,6 +1092,8 @@ class Qwen3NextModel(nn.Module): name.endswith(".bias") or name.endswith("_bias") ) and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -1108,6 +1110,11 @@ class Qwen3NextModel(nn.Module): continue if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + logger.warning_once( + f"Parameter {name} not found in params_dict, skip loading" + ) + continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 635c3bfdc65c7..089129e443c01 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.attention.backends.registry import AttentionBackendEnum from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY @@ -192,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module): mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, + multimodal_config: MultiModalConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: @@ -205,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", ) self.mlp = Qwen3_VisionMLP( @@ -299,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module): vision_config, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | 
None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -347,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", ) for layer_idx in range(vision_config.depth) @@ -376,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module): ] ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) + self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -1188,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen3Omni_VisionTransformer( vision_config=thinker_config.vision_config, norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) self.quant_config = quant_config diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fcd58c4d33cd7..c0589986d1fe8 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import AttentionBackendEnum from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group from vllm.logger import init_logger @@ -67,12 +67,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.evs import ( + compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions, +) from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItem, MultiModalKwargsItems, + PlaceholderRange, VideoItem, ) from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser @@ -92,6 +99,7 @@ from .interfaces import ( SupportsLoRA, SupportsMRoPE, SupportsMultiModal, + SupportsMultiModalPruning, SupportsPP, _require_is_multimodal, ) @@ -161,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module): bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.linear_fc1 = ColumnParallelLinear( in_features, hidden_features, @@ -198,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module): mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, + multimodal_config: MultiModalConfig | None = 
None, quant_config: QuantizationConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, ) -> None: super().__init__() if norm_layer is None: @@ -213,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend=attn_backend, ) self.mlp = Qwen3_VisionMLP( dim, @@ -223,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module): act_fn=act_fn, bias=True, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -256,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module): spatial_merge_size: int = 2, use_postshuffle_norm: bool = False, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = context_dim * (spatial_merge_size**2) self.use_postshuffle_norm = use_postshuffle_norm @@ -305,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module): vision_config: Qwen3VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -318,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module): self.spatial_merge_unit = self.spatial_merge_size**2 self.temporal_patch_size = vision_config.temporal_patch_size self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes - self.use_data_parallel = use_data_parallel self.num_grid_per_side = int(self.num_position_embeddings**0.5) # NOTE: This is used for creating empty tensor for all_gather for @@ -351,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module): norm_layer=norm_layer, spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, ) self.deepstack_merger_list = nn.ModuleList( @@ -364,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module): use_postshuffle_norm=True, norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", - use_data_parallel=use_data_parallel, ) for layer_idx in range(len(self.deepstack_visual_indexes)) ] ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -394,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module): act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, ) for layer_idx in range(vision_config.depth) ] @@ -1043,13 +1059,39 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) for curr_time in timestamps ] - 
num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token] + + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if video_pruning_rate is not None and video_pruning_rate > 0.0: + total_retained = compute_retained_tokens_count( + tokens_per_frame, + len(frames_idx_token), + video_pruning_rate, + ) + if len(frames_idx_token) == 0: + per_frame_token_counts = [] + elif len(frames_idx_token) == 1: + per_frame_token_counts = [tokens_per_frame] + else: + first_frame_tokens = tokens_per_frame + remaining_tokens = max(total_retained - first_frame_tokens, 0) + base = remaining_tokens // (len(frames_idx_token) - 1) + remainder = remaining_tokens % (len(frames_idx_token) - 1) + per_frame_token_counts = [first_frame_tokens] + for frame_idx in range(1, len(frames_idx_token)): + extra = base + (1 if (frame_idx - 1) < remainder else 0) + per_frame_token_counts.append(extra) + placeholder = [] - for frame_idx in frames_idx_token: - placeholder.extend(frame_idx) + for frame_idx, timestamp_tokens in enumerate(frames_idx_token): + placeholder.extend(timestamp_tokens) + tokens_this_frame = per_frame_token_counts[ + frame_idx if frame_idx < len(per_frame_token_counts) else -1 + ] placeholder.extend( [vision_start_token_id] - + [video_token_id] * num_tokens_per_frame + + [video_token_id] * tokens_this_frame + [vision_end_token_id] ) return PromptUpdateDetails.select_token_id(placeholder, video_token_id) @@ -1190,6 +1232,7 @@ class Qwen3VLForConditionalGeneration( SupportsPP, SupportsMRoPE, SupportsEagle3, + SupportsMultiModalPruning, ): packed_modules_mapping = { "qkv_proj": [ @@ -1232,23 +1275,22 @@ class Qwen3VLForConditionalGeneration( self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) + if not multimodal_config.get_limit_per_prompt( "image" ) and not multimodal_config.get_limit_per_prompt("video"): self.visual = None else: - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen3_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) self.language_model = Qwen3LLMForCausalLM( @@ -1418,6 +1460,109 @@ class Qwen3VLForConditionalGeneration( sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) + def _postprocess_image_embeds_evs( + self, + image_embeds_split: tuple[torch.Tensor, ...], + image_input: Qwen2_5_VLImageInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Append mrope positions for each for images. + This is necessary to recover correct mrope + positions after video pruning + + Args: + image_embeds_split: Tuple of image embeddings for + each image item. + image_input: Image input data. + + Returns: + Tuple of image embeddings for each image item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. 
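# Sketch of the per-frame placeholder budgeting used above when a video
# pruning rate is set: the first frame keeps its full token count, and the
# remaining EVS budget is spread as evenly as possible over the other frames.
# compute_retained_tokens_count is not called here; the total is passed in
# directly, and the example numbers are made up.
def distribute_retained_tokens(tokens_per_frame: int, num_frames: int,
                               total_retained: int) -> list[int]:
    if num_frames == 0:
        return []
    if num_frames == 1:
        return [tokens_per_frame]
    remaining = max(total_retained - tokens_per_frame, 0)
    base, remainder = divmod(remaining, num_frames - 1)
    return [tokens_per_frame] + [
        base + (1 if i < remainder else 0) for i in range(num_frames - 1)
    ]


counts = distribute_retained_tokens(tokens_per_frame=64, num_frames=4, total_retained=102)
assert counts == [64, 13, 13, 12] and sum(counts) == 102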
+ """ + merge_size = self.visual.spatial_merge_size + grid_thw = image_input["image_grid_thw"] + grid_thw_list = grid_thw.tolist() + image_embeds_out = [] + for emb, size in zip(image_embeds_split, grid_thw_list): + positions = compute_mrope_for_media(size, merge_size).to(emb.device) + emb = torch.cat([emb, positions], dim=1) + image_embeds_out.append(emb) + image_embeds_split = image_embeds_out + return tuple(image_embeds_split) + + def _postprocess_video_embeds_evs( + self, + video_embeds_split: tuple[torch.Tensor, ...], + video_input: Qwen2_5_VLVideoInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Prunes video embeddings via Efficient Video Sampling (EVS) + and then appends mrope positions for each retained embeddings + + Args: + video_embeds_split: Tuple of video embeddings for each video item. + video_input: Video input data. + + Returns: + Tuple of video embeddings for each video item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + merge_size = self.visual.spatial_merge_size + + # Cast to long to match the original code + # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa + second_per_grid_ts = video_input.get("second_per_grid_ts") + if second_per_grid_ts is None: + # For Qwen3-VL, second_per_grid_ts might not be available + # Use default value of 1.0 for each video + second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long) + else: + second_per_grid_ts = second_per_grid_ts.long() + tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) + + video_embeds_out = [] + for emb, size, video_second_per_grid_t in zip( + video_embeds_split, grid_thw_list, second_per_grid_ts + ): + # For each video, we compute retention mask using EVS + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + + # Debug logging for EVS pruning + logger.debug( + "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, " + "pruning_rate=%.2f, reduction=%.1f%%)", + emb.shape[0], + retention_mask.sum().item(), + size[0], + size[1], + size[2], + self.video_pruning_rate, + (1 - retention_mask.float().mean().item()) * 100, + ) + + positions = compute_mrope_for_media( + size, + merge_size, + tokens_per_second=tokens_per_second, + video_second_per_grid=video_second_per_grid_t.item(), + ).to(emb.device) + + emb = emb[retention_mask] + positions = positions[retention_mask] + emb = torch.cat([emb, positions], dim=1) + video_embeds_out.append(emb) + return tuple(video_embeds_out) + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: @@ -1440,6 +1585,20 @@ class Qwen3VLForConditionalGeneration( def iter_mm_grid_hw( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] ) -> Iterator[tuple[int, int, int]]: + """ + Iterate over multimodal features and yield grid information. + + For videos with EVS (Efficient Video Sampling) enabled, this function + computes the offset based on the pruned token count rather than relying + on input_tokens.index(), which would fail when tokens are pruned. 
+ + Args: + input_tokens: List of token IDs in the prompt + mm_features: List of multimodal feature specifications + + Yields: + Tuple of (offset, grid_h, grid_w) for each frame/image + """ video_token_id = self.config.video_token_id spatial_merge_size = self.config.vision_config.spatial_merge_size for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): @@ -1452,42 +1611,289 @@ class Qwen3VLForConditionalGeneration( t, h, w = mm_feature.data["video_grid_thw"].data.tolist() llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size - for _ in range(t): - offset = input_tokens.index(video_token_id, offset) - yield offset, llm_grid_h, llm_grid_w - offset += llm_grid_h * llm_grid_w + + # Check if EVS (Efficient Video Sampling) is enabled + is_evs_enabled = ( + hasattr(self, "video_pruning_rate") + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + + if is_evs_enabled: + frame_offsets = self._extract_frame_offsets_from_mask( + mm_feature.mm_position, t + ) + if frame_offsets is not None: + for rel_offset in frame_offsets: + yield offset + rel_offset, llm_grid_h, llm_grid_w + continue + + # If EVS is enabled but mask is missing, this indicates a bug + # in the prompt processing pipeline. The is_embed mask should + # always be present when video_pruning_rate > 0. + raise RuntimeError( + f"EVS is enabled (pruning_rate={self.video_pruning_rate}) " + "but is_embed mask is missing from mm_position. " + "This indicates a bug in prompt processing." + ) + else: + # Non-EVS mode: Use original logic with input_tokens.index() + for _ in range(t): + offset = input_tokens.index(video_token_id, offset) + yield offset, llm_grid_h, llm_grid_w + offset += llm_grid_h * llm_grid_w else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def _get_evs_mask_segments( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[torch.Tensor] | None: + """Extract contiguous segments from EVS is_embed mask. + + The EVS (Efficient Video Sampling) mask marks which placeholder + positions should be filled with video embeddings. This method splits + the mask into contiguous segments, where each segment represents one + retained frame. + + This is a pure function - it does not modify any state and always + returns the same output for the same input (idempotent). + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frame segments + + Returns: + List of tensors, each containing indices for one frame segment, + or None if EVS is not enabled or validation fails. 
+ """ + is_embed_mask = getattr(mm_position, "is_embed", None) + if is_embed_mask is None: + return None + + # Find all True positions in the mask + mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1) + true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten() + if true_indices.numel() == 0: + return None + + # Split into contiguous segments (where diff > 1 indicates a gap) + if true_indices.numel() == 1: + segments = [true_indices] + else: + diffs = torch.diff(true_indices) + split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten() + if split_points.numel() == 0: + segments = [true_indices] + else: + segments = torch.tensor_split( + true_indices, split_points.add(1).tolist() + ) + + # Validate segment count matches expected frames + if len(segments) < expected_frames: + logger.debug( + "EVS mask segments (%d) do not match expected frames (%d)", + len(segments), + expected_frames, + ) + return None + + return segments[:expected_frames] + + def _extract_frame_offsets_from_mask( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return relative offsets for each EVS-retained frame. + + The prompt processor stores a boolean mask inside ``mm_position`` that + marks which placeholder locations should be populated with video + embeddings. By splitting that mask into contiguous runs we can recover + the start of every retained frame without probing ``input_tokens``. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames + + Returns: + List of starting offsets (relative to mm_position) for each frame, + or None if EVS is not enabled. + """ + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: + return None + + return [int(segment[0].item()) for segment in segments] + + def _get_actual_frame_token_counts( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return actual token count for each EVS-retained frame. + + This function calculates the actual number of tokens per frame by + analyzing the is_embed mask, accounting for EVS pruning. Each frame + may have a different token count due to content-aware pruning. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames + + Returns: + List of token counts for each frame, or None if EVS is not enabled. + """ + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: + return None + + return [len(seg) for seg in segments] + + def recompute_mrope_positions( + self, + input_ids: list[int], + multimodal_embeddings: tuple[torch.Tensor, ...], + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. + + Args: + input_ids: (N,) All input tokens of the prompt (Containing + entire sequence). + multimodal_embeddings: Tuple of multimodal embeddings. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). 
+ """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Device + device = ( + multimodal_embeddings[0].device + if len(multimodal_embeddings) + else mrope_positions.device + ) + + # Tensors + input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) + + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] + + positions, mrope_positions_delta = recompute_mrope_positions( + input_ids_t, + mm_embeddings_pos, + mrope_positions, + num_computed_tokens, + vision_start_token_id, + image_token_id, + video_token_id, + ) + + return tuple(mm_embeddings_out), positions, mrope_positions_delta + def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + # Pre-collect actual frame token counts for EVS mode + frame_token_counts_map = {} + for mm_feature in mm_features: + if mm_feature.modality == "video": + is_evs_enabled = ( + hasattr(self, "video_pruning_rate") + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + if is_evs_enabled: + t = mm_feature.data["video_grid_thw"].data.tolist()[0] + token_counts = self._get_actual_frame_token_counts( + mm_feature.mm_position, t + ) + assert token_counts is not None, ( + "EVS enabled but failed to extract frame token counts " + "from is_embed mask" + ) + frame_token_counts_map[mm_feature.mm_position.offset] = token_counts + llm_pos_ids_list = [] st = 0 + frame_counts_idx = {} + for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( input_tokens, mm_features ): text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( + + # Determine actual token count for this frame + base_offset = None + for feat_offset in frame_token_counts_map: + if offset >= feat_offset: + base_offset = feat_offset + + if base_offset is not None: + # EVS mode: use actual token count from is_embed mask + assert base_offset in frame_token_counts_map, ( + f"Found base_offset {base_offset} but not in frame_token_counts_map" + ) + + if base_offset not in frame_counts_idx: + frame_counts_idx[base_offset] = 0 + + counts = frame_token_counts_map[base_offset] + idx = frame_counts_idx[base_offset] + + assert idx < len(counts), ( + f"EVS frame index {idx} out of range (total frames: {len(counts)})" + ) + + actual_frame_tokens = counts[idx] + frame_counts_idx[base_offset] += 1 + else: + # Non-EVS mode (or image): use theoretical grid size + actual_frame_tokens = llm_grid_h * llm_grid_w + + # Add text segment + text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) + llm_pos_ids_list.append(text_positions) + st_idx += text_len + # Add frame segment with actual token count (not theoretical) grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - llm_pos_ids_list.append(grid_indices + text_len + st_idx) - st = offset + llm_grid_h * llm_grid_w + # Only take the first actual_frame_tokens positions + frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx + llm_pos_ids_list.append(frame_positions) + # Update st using actual token count + st = offset + actual_frame_tokens + + # Handle final text segment if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - 
llm_pos_ids_list.append( + final_text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) + llm_pos_ids_list.append(final_text_positions) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return torch.from_numpy(llm_positions), mrope_position_delta def get_language_model(self) -> torch.nn.Module: @@ -1508,9 +1914,17 @@ class Qwen3VLForConditionalGeneration( multimodal_input = mm_input_by_modality[modality] if modality == "image": image_embeddings = self._process_image_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input + ) multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + video_embeddings = self._postprocess_video_embeds_evs( + video_embeddings, multimodal_input + ) multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index a054bd5b3831e..3186804488e57 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -419,6 +419,10 @@ class Qwen3VLMoeForConditionalGeneration( self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) if not multimodal_config.get_limit_per_prompt( "image" @@ -429,8 +433,8 @@ class Qwen3VLMoeForConditionalGeneration( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, ) self.language_model = Qwen3MoeLLMForCausalLM( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a4a964bc7c1a6..4575e91e13959 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -264,10 +264,15 @@ _CROSS_ENCODER_MODELS = { _MULTIMODAL_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), + "AudioFlamingo3ForConditionalGeneration": ( + "audioflamingo3", + "AudioFlamingo3ForConditionalGeneration", + ), "AyaVisionForConditionalGeneration": ( "aya_vision", "AyaVisionForConditionalGeneration", ), + "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"), "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ( diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 31cc645099141..45b6e93307ac3 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module): torch.arange(config.max_position_embeddings).unsqueeze(0), ) - self.position_embedding_type = config.position_embedding_type - if self.position_embedding_type != "absolute": - raise ValueError( - "Only 'absolute' position_embedding_type" + " is supported" - ) - def forward( self, input_ids: torch.Tensor, @@ -135,12 
+129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel): def _build_model( self, vllm_config: VllmConfig, prefix: str = "" ) -> BertModel | BertWithRope: - if vllm_config.model_config.hf_config.position_embedding_type == "rotary": - return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) + hf_config = vllm_config.model_config.hf_config + kwargs = dict(vllm_config=vllm_config, prefix=prefix) + if getattr(hf_config, "position_embedding_type", "absolute") == "absolute": + return BertModel(**kwargs, embedding_class=RobertaEmbedding) else: - return BertModel( - vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding - ) + return JinaRobertaModel(**kwargs) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights_list = list(weights) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index bbce01995412c..efdee255ab5eb 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -6,14 +6,14 @@ within a vision language model.""" from collections.abc import Iterable import torch -from einops import rearrange, repeat from torch import nn from torch.nn import functional as F from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention +from vllm.config import MultiModalConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer @@ -25,11 +25,12 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.platforms import current_platform -from .vision import get_vit_attn_backend - class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: @@ -147,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module): return patch_embeds -# copy from flash_attn/layers/rotary.py -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, @@ -190,14 +157,20 @@ def apply_rotary_pos_emb( ) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if is_flash_attn_backend and not current_platform.is_xpu(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - apply_rotary_emb_func = apply_rotary_emb + apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) + + if is_flash_attn_backend and not current_platform.is_cuda(): + apply_rotary_emb_func = apply_rotary_emb.forward_cuda else: - apply_rotary_emb_func = apply_rotary_emb_torch - q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k) + apply_rotary_emb_func = apply_rotary_emb.forward_native + + q_embed = apply_rotary_emb_func(q, cos, sin) + k_embed = apply_rotary_emb_func(k, cos, sin) + return q_embed, k_embed @@ -208,6 +181,7 @@ class Siglip2Attention(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend_override: AttentionBackendEnum | None = None, @@ -227,20 +201,25 @@ class Siglip2Attention(nn.Module): self.dropout = config.attention_dropout self.is_causal = False - # TODO(Isotr0py): Enable data parallel after we support - # disabling TP on parallel linear layer + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, total_num_heads=self.num_heads, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, ) self.out_proj = RowParallelLinear( input_size=self.embed_dim, output_size=self.embed_dim, quant_config=quant_config, prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, ) self.tp_size = ( @@ -249,31 +228,13 @@ class Siglip2Attention(nn.Module): self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.use_rope = config.use_rope - # Detect attention implementation. 
- self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_heads_per_partition, head_size=self.head_dim, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + prefix=f"{prefix}.attn", + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - self.attn_backend = AttentionBackendEnum.TORCH_SDPA - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } - def forward( self, hidden_states: torch.Tensor, @@ -298,46 +259,23 @@ class Siglip2Attention(nn.Module): keys.unsqueeze(0), cos, sin, - self.is_flash_attn_backend, + self.attn.is_flash_attn_backend, ) queries = queries.squeeze(0) keys = keys.squeeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - if self.is_flash_attn_backend: - attn_output = self.flash_attn_varlen_func( - queries, - keys, - values, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - ).reshape(seq_length, -1) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - batch_size = cu_seqlens.shape[0] - 1 - outputs = [] - cu = cu_seqlens.tolist() - for i in range(batch_size): - start_idx = cu[i] - end_idx = cu[i + 1] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output = self.attn( + query=queries.unsqueeze(0), + key=keys.unsqueeze(0), + value=values.unsqueeze(0), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + attn_output = attn_output.reshape( + seq_length, self.num_heads_per_partition * self.head_dim + ) - # Each sequence is processed independently. 
- q_i = queries[start_idx:end_idx].unsqueeze(0) - k_i = keys[start_idx:end_idx].unsqueeze(0) - v_i = values[start_idx:end_idx].unsqueeze(0) - - # (1, seq_len, num_heads, head_dim) -> - # (1, num_heads, seq_len, head_dim) - q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)] - - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) - output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1) - outputs.append(output_i) - - attn_output = torch.cat(outputs, dim=0) attn_output, _ = self.out_proj(attn_output) return attn_output @@ -347,25 +285,30 @@ class Siglip2MLP(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() self.config = config + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.activation_fn = get_act_fn(config.hidden_act) - # TODO(Isotr0py): Enable data parallel after we support - # disabling TP on parallel linear layer self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -380,9 +323,8 @@ class Siglip2EncoderLayer(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -390,16 +332,15 @@ class Siglip2EncoderLayer(nn.Module): self.self_attn = Siglip2Attention( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Siglip2MLP( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -444,9 +385,8 @@ class Siglip2Encoder(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -455,9 +395,8 @@ class Siglip2Encoder(nn.Module): Siglip2EncoderLayer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{idx}", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) for idx in range(config.num_hidden_layers) ] @@ -630,9 +569,8 @@ class Siglip2VisionTransformer(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -642,9 +580,8 @@ class Siglip2VisionTransformer(nn.Module): 
self.encoder = Siglip2Encoder( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.encoder", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -671,18 +608,16 @@ class Siglip2NavitModel(torch.nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.vision_model = Siglip2VisionTransformer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vision_model", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) def forward( diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py index a74fd80c06d8c..fbf5594851ece 100644 --- a/vllm/model_executor/models/swin.py +++ b/vllm/model_executor/models/swin.py @@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module): self, hidden_states: torch.Tensor, attention_mask: torch.FloatTensor | None = None, - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, ) -> tuple[torch.Tensor, ...]: batch_size, dim, num_channels = hidden_states.shape @@ -201,12 +200,9 @@ class SwinAttention(nn.Module): self, hidden_states: torch.Tensor, attention_mask: torch.FloatTensor | None = None, - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, ) -> tuple[torch.Tensor]: - self_outputs = self.self( - hidden_states, attention_mask, head_mask, output_attentions - ) + self_outputs = self.self(hidden_states, attention_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] return outputs @@ -339,18 +335,14 @@ class SwinStage(nn.Module): self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int], - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, always_partition: bool | None = False, ) -> tuple[torch.Tensor]: height, width = input_dimensions for i, layer_module in enumerate(self.blocks): - layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = layer_module( hidden_states, input_dimensions, - layer_head_mask, output_attentions, always_partition, ) @@ -425,17 +417,13 @@ class SwinEncoder(nn.Module): self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int], - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, always_partition: bool | None = False, ) -> tuple[torch.Tensor]: for i, layer_module in enumerate(self.layers): - layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = layer_module( hidden_states, input_dimensions, - layer_head_mask, output_attentions, always_partition, ) @@ -473,7 +461,6 @@ class SwinModel(nn.Module): def forward( self, pixel_values: torch.FloatTensor | None = None, - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = None, ) -> tuple[torch.Tensor]: embedding_output, input_dimensions = self.embeddings(pixel_values) @@ -481,7 +468,6 @@ class SwinModel(nn.Module): encoder_outputs = self.encoder( embedding_output, input_dimensions, - head_mask=head_mask, output_attentions=output_attentions, ) diff --git a/vllm/model_executor/models/ultravox.py 
b/vllm/model_executor/models/ultravox.py index 32a2ba1ef38f7..7e1b7c90c9204 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -5,6 +5,7 @@ """PyTorch Ultravox model.""" import copy +import inspect from collections.abc import Iterable, Mapping, Sequence from types import SimpleNamespace from typing import Annotated, Any, Literal, TypeAlias @@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin): ) hidden_states = hidden_states + positions + # Backward compatibility for Transformers v4 where layer_head_mask + # was a required argument for WhisperEncoderLayer.forward + kwargs = {} + if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters: + kwargs["layer_head_mask"] = None + for layer in self.layers: layer_outputs = layer( hidden_states, attention_mask=extended_attention_mask, - layer_head_mask=None, + **kwargs, ) hidden_states = layer_outputs[0] @@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder): attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states) + # Backward compatibility for Transformers v4 where layer_head_mask + # was a required argument for WhisperEncoderLayer.forward + kwargs = {} + if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters: + kwargs["layer_head_mask"] = None + for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, attention_mask, - layer_head_mask=None, + **kwargs, ) hidden_states = layer_outputs[0] diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 7602eca9c3257..024c50f1207ed 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -11,7 +11,7 @@ import torch from transformers import PretrainedConfig from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -88,14 +88,11 @@ def get_vit_attn_backend( """ Get the available attention backend for Vision Transformer. 
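+
+    If attn_backend_override is given, it is forwarded to the current
+    platform, which validates it against get_supported_vit_attn_backends()
+    before returning it; otherwise the platform falls back to its default
+    ViT attention backend.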
""" - if attn_backend_override is not None: - return attn_backend_override - - selected_backend = get_current_vllm_config().attention_config.backend - if selected_backend is not None: - return selected_backend - - return current_platform.get_vit_attn_backend(head_size, dtype) + return current_platform.get_vit_attn_backend( + head_size, + dtype, + backend=attn_backend_override, + ) def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 7b408248ec74c..331f0c54ecfbc 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -51,7 +51,8 @@ from vllm.multimodal.processing import ( ) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .utils import init_vllm_registered_model, maybe_prefix diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 062547401c3cf..51b8f77f29088 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -127,13 +127,21 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]): def load_bytes(self, data: bytes) -> torch.Tensor: buffer = BytesIO(data) - return torch.load(buffer, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(buffer, weights_only=True) + return tensor.to_dense() def load_base64(self, media_type: str, data: str) -> torch.Tensor: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> torch.Tensor: - return torch.load(filepath, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(filepath, weights_only=True) + return tensor.to_dense() def encode_base64(self, media: torch.Tensor) -> str: return tensor2base64(media) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 789421e9e0c3b..1506ecb8c7aa0 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -122,13 +122,21 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): def load_bytes(self, data: bytes) -> torch.Tensor: buffer = BytesIO(data) - return torch.load(buffer, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(buffer, weights_only=True) + return tensor.to_dense() def load_base64(self, media_type: str, data: str) -> torch.Tensor: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> torch.Tensor: - return torch.load(filepath, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(filepath, weights_only=True) + return tensor.to_dense() def encode_base64(self, media: torch.Tensor) -> str: return pybase64.b64encode(media.numpy()).decode("utf-8") diff --git 
a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index c3c7cc2c3da0e..a69afc3176cab 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -120,7 +120,7 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): return self.data[index] def get_processor_data(self) -> Mapping[str, object]: - return {f"{self.modality}s": self.data} + return {f"{self.modality}s": self.get_all()} def get_passthrough_data(self) -> Mapping[str, object]: return {} diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index d961dcf13e53e..e1b461d79a655 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum logger = init_logger(__name__) if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig else: VllmConfig = None @@ -126,21 +127,13 @@ class CpuPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: logger.info("Cannot use %s backend on CPU.", selected_backend) - if use_mla: + if attn_selector_config.use_mla: raise NotImplementedError("MLA is not supported on CPU.") - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on CPU.") return AttentionBackendEnum.CPU_ATTN.get_path() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 915392a4125f9..2dc4ba5d70cac 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -7,14 +7,13 @@ pynvml. However, it should not initialize cuda context. import os from collections.abc import Callable from functools import cache, wraps -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar import torch from typing_extensions import ParamSpec # import custom ops, trigger op registration import vllm._C # noqa -from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml @@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import CacheDType else: @@ -182,7 +182,7 @@ class CudaPlatformBase(Platform): if vllm_config.attention_config.backend is None: # Default case - if cls.is_device_capability(100) and not use_sparse: + if cls.is_device_capability_family(100) and not use_sparse: # Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2). 
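+                # is_device_capability_family(100) matches any 10.x device,
+                # not only 10.0, so newer Blackwell variants are covered too.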
use_cutlass_mla = True # Set the backend in AttentionConfig so it's used during @@ -255,36 +255,11 @@ class CudaPlatformBase(Platform): torch.cuda.reset_peak_memory_stats(device) return torch.cuda.max_memory_allocated(device) - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> "AttentionBackendEnum": - # Try FlashAttention first - if (cc := cls.get_device_capability()) and cc.major >= 8: - try: - backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() - if backend_class.supports_head_size( - head_size - ) and backend_class.supports_dtype(dtype): - return AttentionBackendEnum.FLASH_ATTN - except ImportError: - pass - - return AttentionBackendEnum.TORCH_SDPA - @classmethod def get_valid_backends( cls, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability: DeviceCapability, + attn_selector_config: "AttentionSelectorConfig", ) -> tuple[ list[tuple["AttentionBackendEnum", int]], dict["AttentionBackendEnum", list[str]], @@ -292,21 +267,15 @@ class CudaPlatformBase(Platform): valid_backends_priorities = [] invalid_reasons = {} - backend_priorities = _get_backend_priorities(use_mla, device_capability) + backend_priorities = _get_backend_priorities( + attn_selector_config.use_mla, device_capability + ) for priority, backend in enumerate(backend_priorities): try: backend_class = backend.get_class() invalid_reasons_i = backend_class.validate_configuration( - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability=device_capability, + **attn_selector_config._asdict(), ) except ImportError: invalid_reasons_i = ["ImportError"] @@ -321,37 +290,19 @@ class CudaPlatformBase(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: "CacheDType | None", - block_size: int | None, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: - if attn_type is None: - attn_type = AttentionType.DECODER - device_capability = cls.get_device_capability() assert device_capability is not None + attn_selector_config = attn_selector_config._replace(block_size=None) # First try checking just the selected backend, if there is one. if selected_backend is not None: try: backend_class = selected_backend.get_class() invalid_reasons = backend_class.validate_configuration( - head_size, - dtype, - kv_cache_dtype, - None, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability=device_capability, + **attn_selector_config._asdict(), ) except ImportError: invalid_reasons = ["ImportError"] @@ -367,16 +318,8 @@ class CudaPlatformBase(Platform): # No selected backend or the selected backend is invalid, # so we try finding a valid backend. 
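+        # attn_selector_config bundles the per-model attention parameters
+        # (head size, dtype, kv-cache dtype, block size, MLA/sparse flags);
+        # get_valid_backends expands it with _asdict() when validating each
+        # candidate backend against the device capability.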
valid_backends_priorities, invalid_reasons = cls.get_valid_backends( - head_size, - dtype, - kv_cache_dtype, - None, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability=device_capability, + attn_selector_config=attn_selector_config, ) reasons_str = ( "{" @@ -386,11 +329,7 @@ class CudaPlatformBase(Platform): ) + "}" ) - config_str = ( - f"head_size: {head_size}, dtype: {dtype}, " - f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, " - f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}" - ) + config_str = attn_selector_config.__repr__() logger.debug_once( f"Some attention backends are not valid for {cls.device_name} with " f"{config_str}. Reasons: {reasons_str}." @@ -409,14 +348,50 @@ class CudaPlatformBase(Platform): ) selected_index = sorted_indices[0] selected_backend = valid_backends_priorities[selected_index][0] - logger.info( + logger.info_once( "Using %s attention backend out of potential backends: %s", selected_backend.name, - [b[0].name for b in valid_backends_priorities], + tuple(b[0].name for b in valid_backends_priorities), + scope="local", ) return selected_backend.get_path() + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.FLASH_ATTN, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + # Try FlashAttention first + if (cc := cls.get_device_capability()) and cc.major >= 8: + try: + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() + if backend_class.supports_head_size( + head_size + ) and backend_class.supports_dtype(dtype): + return AttentionBackendEnum.FLASH_ATTN + except ImportError: + pass + + return AttentionBackendEnum.TORCH_SDPA + @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f04e94e425257..d4b40045df384 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,7 +7,7 @@ import platform import random import sys from datetime import timedelta -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple, Optional import numpy as np import torch @@ -18,8 +18,8 @@ from vllm.logger import init_logger if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig - from vllm.config.cache import CacheDType from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -222,29 +222,52 @@ class Platform: with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> "AttentionBackendEnum": - return AttentionBackendEnum.TORCH_SDPA - @classmethod def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - 
kv_cache_dtype: "CacheDType | None", - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: """Get the attention backend class of a device.""" return "" + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.TORCH_SDPA, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + """ + Get the vision attention backend class of a device. + + NOTE: ViT Attention should be checked and override in the platform-specific + implementation. we should not override this in any other places, like + the model_executor/models/.py. + + We check if the backend is None or not: + 1. If not, check if the backend is supported by the platform. + 2. If None, continue to the default selection logic. + """ + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention" + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + logger.info_once( + f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention" + ) + return AttentionBackendEnum.TORCH_SDPA + @classmethod def get_device_capability( cls, @@ -301,6 +324,21 @@ class Platform: return current_capability.to_int() == capability + @classmethod + def is_device_capability_family( + cls, + capability: int, + device_id: int = 0, + ) -> bool: + """ + Returns True if the device capability is any .x. + Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x). 
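+
+        For example, is_device_capability_family(100) is True on any device
+        whose capability converts to 100..109 (i.e. 10.x), since both values
+        are compared after integer division by 10.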
+ """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + return (current_capability.to_int() // 10) == (capability // 10) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: """Get the name of a device.""" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 876114c2d33a4..c237f7cf887c1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -3,7 +3,7 @@ import os from functools import cache, lru_cache, wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig logger = init_logger(__name__) @@ -123,8 +124,6 @@ def use_rocm_custom_paged_attention( alibi_slopes: torch.Tensor | None = None, sinks: torch.Tensor | None = None, ) -> bool: - from vllm._aiter_ops import rocm_aiter_ops - GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -140,7 +139,6 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (rocm_aiter_ops.is_pa_attn_enabled()) and sinks is None ) @@ -187,42 +185,19 @@ class RocmPlatform(Platform): if not on_gfx9(): supported_quantization += ["bitsandbytes"] - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> AttentionBackendEnum: - from importlib.util import find_spec - - from vllm._aiter_ops import rocm_aiter_ops - - if rocm_aiter_ops.is_mha_enabled(): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. - return AttentionBackendEnum.ROCM_AITER_FA - - if on_gfx9() and find_spec("flash_attn") is not None: - return AttentionBackendEnum.FLASH_ATTN - - return AttentionBackendEnum.TORCH_SDPA - @classmethod def get_attn_backend_cls( cls, - selected_backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - attn_type: str | None = None, + selected_backend: "AttentionBackendEnum", + attn_selector_config: "AttentionSelectorConfig", ) -> str: from vllm._aiter_ops import rocm_aiter_ops - if use_sparse: - if kv_cache_dtype.startswith("fp8"): + block_size = attn_selector_config.block_size + kv_cache_dtype = attn_selector_config.kv_cache_dtype + + if attn_selector_config.use_sparse: + if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." ) @@ -232,7 +207,7 @@ class RocmPlatform(Platform): logger.info_once("Using Sparse MLA backend on V1 engine.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() - if use_mla: + if attn_selector_config.use_mla: if selected_backend is None: selected_backend = ( AttentionBackendEnum.ROCM_AITER_MLA @@ -322,6 +297,43 @@ class RocmPlatform(Platform): "ROCm. Note that V0 attention backends have been removed." 
) + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + AttentionBackendEnum.TORCH_SDPA, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + from importlib.util import find_spec + + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_mha_enabled(): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. + return AttentionBackendEnum.ROCM_AITER_FA + + if on_gfx9() and find_spec("flash_attn") is not None: + return AttentionBackendEnum.FLASH_ATTN + + return AttentionBackendEnum.TORCH_SDPA + @classmethod def set_device(cls, device: torch.device) -> None: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d6998e7a308af..7c479bf2b6a0e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast import torch from tpu_info import device @@ -16,6 +16,7 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: from typing import TypeAlias + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams @@ -57,17 +58,9 @@ class TpuPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) @@ -75,6 +68,32 @@ class TpuPlatform(Platform): logger.info("Using Pallas V1 backend.") return AttentionBackendEnum.PALLAS.get_path() + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.PALLAS, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention" + f"Supported backends are: {cls.get_supported_vit_attn_backends()}." + ) + logger.info_once(f"Using backend {backend} for vit attention.") + return backend + + logger.info_once( + f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention." 
+ ) + return AttentionBackendEnum.PALLAS + @classmethod def set_device(cls, device: torch.device) -> None: """ diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 0a05750764d8d..af8979af36643 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -3,7 +3,7 @@ import contextlib import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -14,6 +14,7 @@ from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig else: VllmConfig = None @@ -42,15 +43,7 @@ class XPUPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: from vllm.v1.attention.backends.utils import set_kv_cache_layout @@ -60,7 +53,7 @@ class XPUPlatform(Platform): "only NHD layout is supported by XPU attention kernels." ) - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info_once("Using Triton backend.") @@ -71,12 +64,40 @@ class XPUPlatform(Platform): elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " - f"with use_mla: {use_mla}" + f"with use_mla: {attn_selector_config.use_mla}" ) logger.info("Using Flash Attention backend.") return AttentionBackendEnum.FLASH_ATTN.get_path() + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + # XPU only supports FLASH_ATTN for vision attention. + return [ + AttentionBackendEnum.FLASH_ATTN, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: " + f"{cls.get_supported_vit_attn_backends()}." 
+ ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + logger.info_once( + f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention" + ) + return AttentionBackendEnum.FLASH_ATTN + @classmethod def set_device(cls, device: torch.device) -> None: """ @@ -110,12 +131,6 @@ class XPUPlatform(Platform): device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> "AttentionBackendEnum": - return AttentionBackendEnum.FLASH_ATTN - @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/profiler/wrapper.py b/vllm/profiler/wrapper.py index a44a6a5eea0dd..f891a88f90394 100644 --- a/vllm/profiler/wrapper.py +++ b/vllm/profiler/wrapper.py @@ -61,7 +61,7 @@ class WorkerProfiler(ABC): """Call _stop with error handling but no safeguards.""" try: self._stop() - logger.info("Profiler stopped successfully.") + logger.info_once("Profiler stopped successfully.", scope="local") except Exception as e: logger.warning("Failed to stop profiler: %s", e) self._running = False # Always mark as not running, assume stop worked @@ -91,7 +91,7 @@ class WorkerProfiler(ABC): and self._delay_iters > 0 and self._active_iteration_count == self._delay_iters ): - logger.info("Starting profiler after delay...") + logger.info_once("Starting profiler after delay...", scope="local") self._call_start() if self._running: @@ -105,7 +105,9 @@ class WorkerProfiler(ABC): # Automatically stop the profiler after max iters # will be marked as not running, but leave as active so that stop # can clean up properly - logger.info("Max profiling iterations reached. Stopping profiler...") + logger.info_once( + "Max profiling iterations reached. Stopping profiler...", scope="local" + ) self._call_stop() return @@ -125,7 +127,7 @@ class WorkerProfiler(ABC): def shutdown(self) -> None: """Ensure profiler is stopped when shutting down.""" - logger.info_once("Shutting down profiler") + logger.info_once("Shutting down profiler", scope="local") if self._running: self.stop() @@ -156,9 +158,10 @@ class TorchProfilerWrapper(WorkerProfiler): self.profiler_config = profiler_config torch_profiler_trace_dir = profiler_config.torch_profiler_dir if local_rank in (None, 0): - logger.info( + logger.info_once( "Torch profiling enabled. 
Traces will be saved to: %s", torch_profiler_trace_dir, + scope="local", ) logger.debug( "Profiler config: record_shapes=%s," diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py index 3206dbb29fe2e..de3d1296ec734 100644 --- a/vllm/reasoning/mistral_reasoning_parser.py +++ b/vllm/reasoning/mistral_reasoning_parser.py @@ -10,7 +10,7 @@ from vllm.entrypoints.openai.protocol import ( from vllm.logger import init_logger from vllm.reasoning import ReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) diff --git a/vllm/tokenizers/__init__.py b/vllm/tokenizers/__init__.py index 67a6d7c8eb3d9..31e74b1a16e20 100644 --- a/vllm/tokenizers/__init__.py +++ b/vllm/tokenizers/__init__.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .deepseekv32 import DeepseekV32Tokenizer -from .hf import HfTokenizer -from .mistral import MistralTokenizer from .protocol import TokenizerLike from .registry import ( TokenizerRegistry, @@ -15,12 +12,9 @@ from .registry import ( __all__ = [ "TokenizerLike", - "HfTokenizer", - "MistralTokenizer", "TokenizerRegistry", "cached_get_tokenizer", "get_tokenizer", "cached_tokenizer_from_config", "init_tokenizer_from_config", - "DeepseekV32Tokenizer", ] diff --git a/vllm/tokenizers/deepseekv32.py b/vllm/tokenizers/deepseek_v32.py similarity index 81% rename from vllm/tokenizers/deepseekv32.py rename to vllm/tokenizers/deepseek_v32.py index a7fa0f421725a..bf279a5cf67c5 100644 --- a/vllm/tokenizers/deepseekv32.py +++ b/vllm/tokenizers/deepseek_v32.py @@ -2,24 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path +from typing import Any from transformers import BatchEncoding +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + from .deepseek_v32_encoding import encode_messages -from .hf import HfTokenizer, TokenizerLike -from .registry import TokenizerRegistry +from .hf import CachedHfTokenizer +from .protocol import TokenizerLike -@TokenizerRegistry.register("deepseek_v32") -class DeepseekV32Tokenizer(HfTokenizer): - def __init__(self, tokenizer: TokenizerLike): - self.tokenizer = tokenizer - self.name_or_path = ( - tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else "" - ) - self._added_vocab = self.tokenizer.get_added_vocab() - self._added_vocab_size = len(self._added_vocab) - +class DeepseekV32Tokenizer(CachedHfTokenizer): @classmethod def from_pretrained( cls, @@ -40,7 +34,21 @@ class DeepseekV32Tokenizer(HfTokenizer): ) return DeepseekV32Tokenizer(tokenizer) - def apply_chat_template(self, messages, tools=None, **kwargs): + def __init__(self, tokenizer: TokenizerLike) -> None: + super().__init__() + + self.tokenizer = tokenizer + self.name_or_path = getattr(tokenizer, "name_or_path", "") + + self._added_vocab = self.tokenizer.get_added_vocab() + self._added_vocab_size = len(self._added_vocab) + + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: list[dict[str, Any]] | None = None, + **kwargs, + ) -> str | list[int]: thinking = kwargs.get("thinking", False) thinking_mode = "thinking" if not thinking: @@ -49,13 +57,24 @@ class DeepseekV32Tokenizer(HfTokenizer): messages = conversation.copy() if tools is not None and len(tools) > 0: messages.insert(0, {"role": "system"}) - 
messages[0]["tools"] = tools + messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key] # Historical reasoning content is dropped when a new user message is introduced drop_thinking = messages[-1]["role"] == "user" encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) prompt_str = encode_messages(messages, **encode_config) # type: ignore + + if kwargs.get("tokenize", True): + tokenizer_kwargs = { + k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs + } + return self.encode( + prompt_str, + add_special_tokens=False, + **tokenizer_kwargs, + ) + return prompt_str def num_special_tokens_to_add(self) -> int: diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py index 3445073120387..a7b565dca5d8f 100644 --- a/vllm/tokenizers/hf.py +++ b/vllm/tokenizers/hf.py @@ -3,22 +3,18 @@ import contextlib import copy from pathlib import Path -from typing import TYPE_CHECKING +from typing import TypeAlias -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from .protocol import TokenizerLike -from .registry import TokenizerRegistry -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast -def get_cached_tokenizer( - tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast", -) -> TokenizerLike: +def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: """ By default, transformers will recompute multiple tokenizer properties each time they are called, leading to a significant slowdown. @@ -65,11 +61,10 @@ def get_cached_tokenizer( CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" cached_tokenizer.__class__ = CachedTokenizer - return cached_tokenizer # type: ignore + return cached_tokenizer -@TokenizerRegistry.register("hf") -class HfTokenizer(TokenizerLike): +class CachedHfTokenizer(TokenizerLike): @classmethod def from_pretrained( cls, @@ -79,7 +74,7 @@ class HfTokenizer(TokenizerLike): revision: str | None = None, download_dir: str | None = None, **kwargs, - ) -> "TokenizerLike": + ) -> HfTokenizer: try: tokenizer = AutoTokenizer.from_pretrained( path_or_repo_id, diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index 1f44037dd55ec..534b0da484a5d 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -3,10 +3,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, cast +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.logger import init_logger from .protocol import TokenizerLike -from .registry import TokenizerRegistry if TYPE_CHECKING: from mistral_common.protocol.instruct.request import ( @@ -15,9 +16,6 @@ if TYPE_CHECKING: from mistral_common.tokens.tokenizers.tekken import Tekkenizer from transformers import BatchEncoding - from vllm.entrypoints.chat_utils import ChatCompletionMessageParam - from vllm.entrypoints.openai.protocol import ChatCompletionRequest - try: # Transformers v5 from transformers.tokenization_mistral_common import MistralCommonBackend @@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: return tokenizer.unk_id -@TokenizerRegistry.register("mistral") class MistralTokenizer(TokenizerLike): @classmethod def from_pretrained( diff --git 
a/vllm/tokenizers/protocol.py b/vllm/tokenizers/protocol.py index d6a3b0ba9b5f5..28754f9e10d00 100644 --- a/vllm/tokenizers/protocol.py +++ b/vllm/tokenizers/protocol.py @@ -97,7 +97,7 @@ class TokenizerLike(Protocol): messages: list["ChatCompletionMessageParam"], tools: list[dict[str, Any]] | None = None, **kwargs, - ) -> list[int]: + ) -> str | list[int]: raise NotImplementedError def convert_tokens_to_string(self, tokens: list[str]) -> str: diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index 1d44feeee500f..72447ef04e87c 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib.util -from collections.abc import Callable +from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, TypeVar, overload +from typing import TYPE_CHECKING import huggingface_hub -from typing_extensions import assert_never +from typing_extensions import TypeVar, assert_never, deprecated import vllm.envs as envs from vllm.logger import init_logger @@ -24,46 +24,25 @@ from vllm.utils.import_utils import resolve_obj_by_qualname from .protocol import TokenizerLike if TYPE_CHECKING: - from vllm.config import ModelConfig + from vllm.config.model import ModelConfig, RunnerType logger = init_logger(__name__) -_T = TypeVar("_T", bound=type[TokenizerLike]) + +_VLLM_TOKENIZERS = { + "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"), + "hf": ("hf", "CachedHfTokenizer"), + "mistral": ("mistral", "MistralTokenizer"), +} -class TokenizerRegistry: - # Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class) - REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {} +@dataclass +class _TokenizerRegistry: + # Tokenizer mode -> (tokenizer module, tokenizer class) + tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict) - # In-tree tokenizers - @staticmethod - @overload - def register(tokenizer_mode: str) -> Callable[[_T], _T]: ... - - # OOT tokenizers - @staticmethod - @overload - def register(tokenizer_mode: str, module: str, class_name: str) -> None: ... - - @staticmethod - def register( - tokenizer_mode: str, - module: str | None = None, - class_name: str | None = None, - ) -> Callable[[_T], _T] | None: - # In-tree tokenizers - if module is None or class_name is None: - - def wrapper(tokenizer_cls: _T) -> _T: - assert tokenizer_mode not in TokenizerRegistry.REGISTRY - TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls - - return tokenizer_cls - - return wrapper - - # OOT tokenizers - if tokenizer_mode in TokenizerRegistry.REGISTRY: + def register(self, tokenizer_mode: str, module: str, class_name: str) -> None: + if tokenizer_mode in self.tokenizers: logger.warning( "%s.%s is already registered for tokenizer_mode=%r. 
" "It is overwritten by the new one.", @@ -72,36 +51,42 @@ class TokenizerRegistry: tokenizer_mode, ) - TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name) + self.tokenizers[tokenizer_mode] = (module, class_name) return None - @staticmethod - def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike": - if tokenizer_mode not in TokenizerRegistry.REGISTRY: + def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]: + if tokenizer_mode not in self.tokenizers: raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.") - item = TokenizerRegistry.REGISTRY[tokenizer_mode] - if isinstance(item, type): - return item.from_pretrained(*args, **kwargs) - - module, class_name = item + module, class_name = self.tokenizers[tokenizer_mode] logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}") - class_ = resolve_obj_by_qualname(f"{module}.{class_name}") - return class_.from_pretrained(*args, **kwargs) + return resolve_obj_by_qualname(f"{module}.{class_name}") + + def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike: + tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode) + return tokenizer_cls.from_pretrained(*args, **kwargs) -def get_tokenizer( +TokenizerRegistry = _TokenizerRegistry( + { + mode: (f"vllm.tokenizers.{mod_relname}", cls_name) + for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items() + } +) + + +def resolve_tokenizer_args( tokenizer_name: str | Path, *args, + runner_type: "RunnerType" = "generate", tokenizer_mode: str = "auto", - trust_remote_code: bool = False, - revision: str | None = None, - download_dir: str | None = None, **kwargs, -) -> TokenizerLike: - """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" +): + revision: str | None = kwargs.get("revision") + download_dir: str | None = kwargs.get("download_dir") + if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. 
@@ -125,16 +110,6 @@ def get_tokenizer( ) tokenizer_name = tokenizer_path - if tokenizer_mode == "slow": - if kwargs.get("use_fast", False): - raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") - - tokenizer_mode = "hf" - kwargs["use_fast"] = False - - if "truncation_side" not in kwargs: - kwargs["truncation_side"] = "left" - # Separate model folder from file path for GGUF models if is_gguf(tokenizer_name): if check_gguf_file(tokenizer_name): @@ -150,6 +125,21 @@ def get_tokenizer( ) kwargs["gguf_file"] = gguf_file + if "truncation_side" not in kwargs: + if runner_type == "generate" or runner_type == "draft": + kwargs["truncation_side"] = "left" + elif runner_type == "pooling": + kwargs["truncation_side"] = "right" + else: + assert_never(runner_type) + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + + tokenizer_mode = "hf" + kwargs["use_fast"] = False + # Try to use official Mistral tokenizer if possible if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): allow_patterns = ["tekken.json", "tokenizer.model.v*"] @@ -165,49 +155,70 @@ def get_tokenizer( if tokenizer_mode == "auto": tokenizer_mode = "hf" - tokenizer_args = (tokenizer_name, *args) - tokenizer_kwargs = dict( + return tokenizer_mode, tokenizer_name, args, kwargs + + +cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args) + + +def tokenizer_args_from_config(config: "ModelConfig", **kwargs): + return cached_resolve_tokenizer_args( + config.tokenizer, + runner_type=config.runner_type, + tokenizer_mode=config.tokenizer_mode, + revision=config.tokenizer_revision, + trust_remote_code=config.trust_remote_code, + **kwargs, + ) + + +_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike) + + +def get_tokenizer( + tokenizer_name: str | Path, + *args, + tokenizer_cls: type[_T] = TokenizerLike, # type: ignore[assignment] + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, +) -> _T: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" + tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args( + tokenizer_name, + *args, trust_remote_code=trust_remote_code, revision=revision, download_dir=download_dir, **kwargs, ) - if tokenizer_mode == "custom": - logger.warning_once( - "TokenizerRegistry now uses `tokenizer_mode` as the registry key " - "instead of `tokenizer_name`. " - "Please update the definition of `.from_pretrained` in " - "your custom tokenizer to accept `args=%s`, `kwargs=%s`. " - "Then, you can pass `tokenizer_mode=%r` instead of " - "`tokenizer_mode='custom'` when initializing vLLM.", - tokenizer_args, - str(tokenizer_kwargs), - tokenizer_name, - ) + if tokenizer_cls == TokenizerLike: + tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode) + else: + tokenizer_cls_ = tokenizer_cls - tokenizer_mode = str(tokenizer_name) - - tokenizer = TokenizerRegistry.get_tokenizer( - tokenizer_mode, - *tokenizer_args, - **tokenizer_kwargs, - ) + tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs) if not tokenizer.is_fast: logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." 
) - return tokenizer + return tokenizer # type: ignore cached_get_tokenizer = lru_cache(get_tokenizer) def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): + if model_config.skip_tokenizer_init: + return None + return cached_get_tokenizer( model_config.tokenizer, + runner_type=model_config.runner_type, tokenizer_mode=model_config.tokenizer_mode, revision=model_config.tokenizer_revision, trust_remote_code=model_config.trust_remote_code, @@ -215,19 +226,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): ) +@deprecated( + "Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14." +) def init_tokenizer_from_config(model_config: "ModelConfig"): - runner_type = model_config.runner_type - if runner_type == "generate" or runner_type == "draft": - truncation_side = "left" - elif runner_type == "pooling": - truncation_side = "right" - else: - assert_never(runner_type) - - return get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=truncation_side, - ) + return cached_tokenizer_from_config(model_config) diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py new file mode 100644 index 0000000000000..181d8bcba9553 --- /dev/null +++ b/vllm/tool_parsers/__init__.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) + +__all__ = ["ToolParser", "ToolParserManager"] + + +""" +Register a lazy module mapping. + +Example: + ToolParserManager.register_lazy_module( + name="kimi_k2", + module_path="vllm.tool_parsers.kimi_k2_parser", + class_name="KimiK2ToolParser", + ) +""" + + +_TOOL_PARSERS_TO_REGISTER = { + "deepseek_v3": ( # name + "deepseekv3_tool_parser", # filename + "DeepSeekV3ToolParser", # class_name + ), + "deepseek_v31": ( + "deepseekv31_tool_parser", + "DeepSeekV31ToolParser", + ), + "deepseek_v32": ( + "deepseekv32_tool_parser", + "DeepSeekV32ToolParser", + ), + "ernie45": ( + "ernie45_tool_parser", + "Ernie45ToolParser", + ), + "glm45": ( + "glm4_moe_tool_parser", + "Glm4MoeModelToolParser", + ), + "granite-20b-fc": ( + "granite_20b_fc_tool_parser", + "Granite20bFCToolParser", + ), + "granite": ( + "granite_tool_parser", + "GraniteToolParser", + ), + "hermes": ( + "hermes_tool_parser", + "Hermes2ProToolParser", + ), + "hunyuan_a13b": ( + "hunyuan_a13b_tool_parser", + "HunyuanA13BToolParser", + ), + "internlm": ( + "internlm2_tool_parser", + "Internlm2ToolParser", + ), + "jamba": ( + "jamba_tool_parser", + "JambaToolParser", + ), + "kimi_k2": ( + "kimi_k2_tool_parser", + "KimiK2ToolParser", + ), + "llama3_json": ( + "llama_tool_parser", + "Llama3JsonToolParser", + ), + "llama4_json": ( + "llama_tool_parser", + "Llama3JsonToolParser", + ), + "llama4_pythonic": ( + "llama4_pythonic_tool_parser", + "Llama4PythonicToolParser", + ), + "longcat": ( + "longcat_tool_parser", + "LongcatFlashToolParser", + ), + "minimax_m2": ( + "minimax_m2_tool_parser", + "MinimaxM2ToolParser", + ), + "minimax": ( + "minimax_tool_parser", + "MinimaxToolParser", + ), + "mistral": ( + "mistral_tool_parser", + "MistralToolParser", + ), + "olmo3": ( + "olmo3_tool_parser", + "Olmo3PythonicToolParser", + ), + "openai": ( + "openai_tool_parser", + "OpenAIToolParser", + ), + "phi4_mini_json": ( + "phi4mini_tool_parser", + 
"Phi4MiniJsonToolParser", + ), + "pythonic": ( + "pythonic_tool_parser", + "PythonicToolParser", + ), + "qwen3_coder": ( + "qwen3coder_tool_parser", + "Qwen3CoderToolParser", + ), + "qwen3_xml": ( + "qwen3xml_tool_parser", + "Qwen3XMLToolParser", + ), + "seed_oss": ( + "seed_oss_tool_parser", + "SeedOssToolParser", + ), + "step3": ( + "step3_tool_parser", + "Step3ToolParser", + ), + "xlam": ( + "xlam_tool_parser", + "xLAMToolParser", + ), + "gigachat3": ( + "gigachat3_tool_parser", + "GigaChat3ToolParser", + ), +} + + +def register_lazy_tool_parsers(): + for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items(): + module_path = f"vllm.tool_parsers.{file_name}" + ToolParserManager.register_lazy_module(name, module_path, class_name) + + +register_lazy_tool_parsers() diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/tool_parsers/abstract_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py rename to vllm/tool_parsers/abstract_tool_parser.py index 87ef2e0786a94..e2ccb1dad9907 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/tool_parsers/abstract_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( ResponsesRequest, ResponseTextConfig, ) -from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools from vllm.logger import init_logger from vllm.sampling_params import ( StructuredOutputsParams, ) from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.utils import get_json_schema_from_tools from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import import_from_path @@ -203,7 +203,7 @@ class ToolParserManager: Example: ToolParserManager.register_lazy_module( name="kimi_k2", - module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser", + module_path="vllm.tool_parsers.kimi_k2_parser", class_name="KimiK2ToolParser", ) """ diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/tool_parsers/deepseekv31_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py rename to vllm/tool_parsers/deepseekv31_tool_parser.py index 10de3dabf985c..33383e1bc0739 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +++ b/vllm/tool_parsers/deepseekv31_tool_parser.py @@ -15,11 +15,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py b/vllm/tool_parsers/deepseekv32_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py rename to vllm/tool_parsers/deepseekv32_tool_parser.py index 4973deb7cefa8..db081178fdeae 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py +++ b/vllm/tool_parsers/deepseekv32_tool_parser.py @@ -17,11 +17,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = 
init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/tool_parsers/deepseekv3_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py rename to vllm/tool_parsers/deepseekv3_tool_parser.py index 66b14875dce25..f8cf559f2284a 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/tool_parsers/deepseekv3_tool_parser.py @@ -15,11 +15,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py b/vllm/tool_parsers/ernie45_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py rename to vllm/tool_parsers/ernie45_tool_parser.py index d054d8e4b8651..79193787b3b3b 100644 --- a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py +++ b/vllm/tool_parsers/ernie45_tool_parser.py @@ -15,11 +15,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py b/vllm/tool_parsers/gigachat3_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py rename to vllm/tool_parsers/gigachat3_tool_parser.py index dd27ffa83cfc4..27a6bc1a7bad8 100644 --- a/vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py +++ b/vllm/tool_parsers/gigachat3_tool_parser.py @@ -16,9 +16,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/tool_parsers/glm4_moe_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py rename to vllm/tool_parsers/glm4_moe_tool_parser.py index 165346adb3d93..d254fcb5240a5 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/tool_parsers/glm4_moe_tool_parser.py @@ -18,11 +18,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/tool_parsers/granite_20b_fc_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py rename to vllm/tool_parsers/granite_20b_fc_tool_parser.py index df1b590526b1a..d841fb57ac87e 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py 
+++ b/vllm/tool_parsers/granite_20b_fc_tool_parser.py @@ -19,17 +19,17 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import ( +from vllm.tool_parsers.utils import ( consume_space, find_common_prefix, is_complete_json, partial_json_loads, ) -from vllm.logger import init_logger -from vllm.tokenizers import TokenizerLike logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/tool_parsers/granite_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py rename to vllm/tool_parsers/granite_tool_parser.py index 14b0ca0abe357..7abfdf72849d9 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/tool_parsers/granite_tool_parser.py @@ -17,17 +17,17 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import ( +from vllm.tool_parsers.utils import ( consume_space, find_common_prefix, is_complete_json, partial_json_loads, ) -from vllm.logger import init_logger -from vllm.tokenizers import TokenizerLike logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/tool_parsers/hermes_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py rename to vllm/tool_parsers/hermes_tool_parser.py index 19c1c83268ed4..4b1dea7edf27a 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/tool_parsers/hermes_tool_parser.py @@ -18,11 +18,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/tool_parsers/hunyuan_a13b_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py rename to vllm/tool_parsers/hunyuan_a13b_tool_parser.py index d2419b5d84ead..c739821368042 100644 --- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +++ b/vllm/tool_parsers/hunyuan_a13b_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) +from vllm.tool_parsers.utils import consume_space from vllm.utils import random_uuid logger = init_logger(__name__) diff --git 
a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/tool_parsers/internlm2_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py rename to vllm/tool_parsers/internlm2_tool_parser.py index 67788358543e9..e87efe3275a71 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/tool_parsers/internlm2_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) +from vllm.tool_parsers.utils import extract_intermediate_diff logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/tool_parsers/jamba_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py rename to vllm/tool_parsers/jamba_tool_parser.py index 4655da8dd4542..7f3de0b38a33c 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/tool_parsers/jamba_tool_parser.py @@ -18,10 +18,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.utils import extract_intermediate_diff logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/tool_parsers/kimi_k2_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py rename to vllm/tool_parsers/kimi_k2_tool_parser.py index 07db52ebd5af1..c215b7978854e 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/tool_parsers/kimi_k2_tool_parser.py @@ -15,11 +15,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/tool_parsers/llama4_pythonic_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py rename to vllm/tool_parsers/llama4_pythonic_tool_parser.py index 1d6de9244066e..3c5409bbfaf42 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/tool_parsers/llama4_pythonic_tool_parser.py @@ -18,10 +18,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py 
b/vllm/tool_parsers/llama_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py rename to vllm/tool_parsers/llama_tool_parser.py index e1fe6e90dfd0b..b0dfe05c8e556 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/tool_parsers/llama_tool_parser.py @@ -20,15 +20,15 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import ( +from vllm.tool_parsers.utils import ( find_common_prefix, is_complete_json, partial_json_loads, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py b/vllm/tool_parsers/longcat_tool_parser.py similarity index 93% rename from vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py rename to vllm/tool_parsers/longcat_tool_parser.py index 76d76a4aa35a1..72f13559a9222 100644 --- a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py +++ b/vllm/tool_parsers/longcat_tool_parser.py @@ -3,8 +3,8 @@ import regex as re -from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser class LongcatFlashToolParser(Hermes2ProToolParser): diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py b/vllm/tool_parsers/minimax_m2_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py rename to vllm/tool_parsers/minimax_m2_tool_parser.py index b595a98f35555..dcb2b64f6e73c 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py +++ b/vllm/tool_parsers/minimax_m2_tool_parser.py @@ -17,11 +17,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/tool_parsers/minimax_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py rename to vllm/tool_parsers/minimax_tool_parser.py index 1025041037c6e..86e1433c6e360 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/tool_parsers/minimax_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) +from vllm.tool_parsers.utils import extract_intermediate_diff logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py rename to vllm/tool_parsers/mistral_tool_parser.py index bc827f045606c..49a175f69f434 100644 
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -6,6 +6,7 @@ from collections.abc import Sequence from enum import Enum, auto from random import choices from string import ascii_letters, digits +from typing import Any import ijson import regex as re @@ -20,11 +21,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike logger = init_logger(__name__) @@ -84,6 +86,7 @@ class MistralToolParser(ToolParser): # initialize properties used for state when parsing tool calls in # streaming mode + self.prev_tool_call_arr: list[dict[str, Any]] = [] self.current_tool_id: int = -1 self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START diff --git a/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py b/vllm/tool_parsers/olmo3_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py rename to vllm/tool_parsers/olmo3_tool_parser.py index baff33bd7e8ac..8cd6a84a9f6b1 100644 --- a/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py +++ b/vllm/tool_parsers/olmo3_tool_parser.py @@ -18,10 +18,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/tool_parsers/openai_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py rename to vllm/tool_parsers/openai_tool_parser.py index a3cf793ed3a6d..db92ea8982d70 100644 --- a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +++ b/vllm/tool_parsers/openai_tool_parser.py @@ -12,10 +12,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger if TYPE_CHECKING: from vllm.tokenizers import TokenizerLike diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/tool_parsers/phi4mini_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py rename to vllm/tool_parsers/phi4mini_tool_parser.py index acb25ea2768e1..9003429d8c6f2 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/tool_parsers/phi4mini_tool_parser.py @@ -16,10 +16,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/tool_parsers/pythonic_tool_parser.py similarity index 99% rename from 
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py rename to vllm/tool_parsers/pythonic_tool_parser.py index abeb923b93227..476a62d5f5273 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/tool_parsers/pythonic_tool_parser.py @@ -19,10 +19,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py rename to vllm/tool_parsers/qwen3coder_tool_parser.py index d49b14690ef03..d1a3cbeaafc7d 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -18,11 +18,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py b/vllm/tool_parsers/qwen3xml_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py rename to vllm/tool_parsers/qwen3xml_tool_parser.py index 03862ff432a5d..107f791654a1a 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py +++ b/vllm/tool_parsers/qwen3xml_tool_parser.py @@ -19,11 +19,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/tool_parsers/seed_oss_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py rename to vllm/tool_parsers/seed_oss_tool_parser.py index c7947faad1923..206072e65b10f 100644 --- a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +++ b/vllm/tool_parsers/seed_oss_tool_parser.py @@ -21,11 +21,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/tool_parsers/step3_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py rename to vllm/tool_parsers/step3_tool_parser.py index 9213d6859dd93..acd99bf56d0b6 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/tool_parsers/step3_tool_parser.py @@ -17,11 +17,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - 
ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) from vllm.utils import random_uuid logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/tool_parsers/utils.py similarity index 100% rename from vllm/entrypoints/openai/tool_parsers/utils.py rename to vllm/tool_parsers/utils.py diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/tool_parsers/xlam_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py rename to vllm/tool_parsers/xlam_tool_parser.py index effd2bd08b42a..9c2b585fe9fdb 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/tool_parsers/xlam_tool_parser.py @@ -17,7 +17,7 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) from vllm.logger import init_logger diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fb88c62dc5b23..a11d37b4b2edf 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -66,6 +66,7 @@ class LazyConfigDict(dict): _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( afmoe="AfmoeConfig", + bagel="BagelConfig", chatglm="ChatGLMConfig", deepseek_vl_v2="DeepseekVLV2Config", deepseek_v32="DeepseekV3Config", @@ -617,6 +618,28 @@ def get_config( hf_overrides=hf_overrides_kw, **kwargs, ) + + # Patching defaults for GGUF models + if _is_gguf: + # Some models have different default values between GGUF and HF. + def apply_gguf_default(key: str, gguf_default: Any): + """ + Apply GGUF defaults unless explicitly configured. + + This function reads/writes external `config` and `config_dict`. + If the specified `key` is not in `config_dict` (i.e. not explicitly + configured and the default HF value is used), it updates the + corresponding `config` value to `gguf_default`. + """ + if key not in config_dict: + config.update({key: gguf_default}) + + # Apply architecture-specific GGUF defaults. + if config.model_type in {"qwen3_moe"}: + # Qwen3 MoE: norm_topk_prob is always true. + # Note that, this parameter is always false (HF default) on Qwen2 MoE. 
+ apply_gguf_default("norm_topk_prob", True) + # Special architecture mapping check for GGUF models if _is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index e536ca8521325..54fe1b8d7b523 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -16,6 +16,7 @@ import importlib _CLASS_TO_MODULE: dict[str, str] = { "AfmoeConfig": "vllm.transformers_utils.configs.afmoe", + "BagelConfig": "vllm.transformers_utils.configs.bagel", "ChatGLMConfig": "vllm.transformers_utils.configs.chatglm", "DeepseekVLV2Config": "vllm.transformers_utils.configs.deepseek_vl2", "DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr", @@ -54,6 +55,7 @@ _CLASS_TO_MODULE: dict[str, str] = { __all__ = [ "AfmoeConfig", + "BagelConfig", "ChatGLMConfig", "DeepseekVLV2Config", "DeepseekV3Config", diff --git a/vllm/transformers_utils/configs/bagel.py b/vllm/transformers_utils/configs/bagel.py new file mode 100644 index 0000000000000..53347ef452138 --- /dev/null +++ b/vllm/transformers_utils/configs/bagel.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import PretrainedConfig, SiglipVisionConfig +from transformers.models.qwen2 import Qwen2Config + + +class BagelConfig(PretrainedConfig): + """Configuration class for BAGEL model.""" + + model_type = "bagel" + + def __init__( + self, + visual_gen: bool = True, + visual_und: bool = True, + llm_config: dict | Qwen2Config | None = None, + vit_config: dict | SiglipVisionConfig | None = None, + vae_config: dict | None = None, + latent_patch_size: int = 2, + max_latent_size: int = 32, + vit_max_num_patch_per_side: int = 70, + connector_act: str = "gelu_pytorch_tanh", + interpolate_pos: bool = False, + timestep_shift: float = 1.0, + **kwargs, + ): + super().__init__(**kwargs) + self.visual_gen = visual_gen + self.visual_und = visual_und + + # Convert dict configs to proper config objects + if isinstance(llm_config, dict): + self.llm_config = Qwen2Config(**llm_config) + else: + self.llm_config = llm_config or Qwen2Config() + + if isinstance(vit_config, dict): + self.vit_config = SiglipVisionConfig(**vit_config) + else: + self.vit_config = vit_config or SiglipVisionConfig() + + self.vae_config = vae_config or {"z_channels": 16, "downsample": 8} + self.latent_patch_size = latent_patch_size + self.max_latent_size = max_latent_size + self.vit_max_num_patch_per_side = vit_max_num_patch_per_side + self.connector_act = connector_act + self.interpolate_pos = interpolate_pos + self.timestep_shift = timestep_shift + + @property + def hidden_size(self) -> int: + """Return the hidden size of the language model.""" + return self.llm_config.hidden_size diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index b49fdbe9ce776..af25dbe4ccdfe 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -8,6 +8,7 @@ reasons: - There is a need to override the existing processor to support vLLM. 
""" +from vllm.transformers_utils.processors.bagel import BagelProcessor from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor @@ -15,6 +16,7 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor __all__ = [ + "BagelProcessor", "DeepseekVLV2Processor", "HunYuanVLProcessor", "HunYuanVLImageProcessor", diff --git a/vllm/transformers_utils/processors/bagel.py b/vllm/transformers_utils/processors/bagel.py new file mode 100644 index 0000000000000..850e64f2fad1e --- /dev/null +++ b/vllm/transformers_utils/processors/bagel.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +"""BAGEL processor for image and text inputs.""" + +from transformers import AutoProcessor +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + + +class BagelProcessor(ProcessorMixin): + """ + Constructs a BAGEL processor which wraps a + SigLIP image processor and a Qwen2 tokenizer. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __call__( + self, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] = None, + images: ImageInput = None, + **kwargs, + ): + """ + Main method to prepare for the model one or several sequences(s) and image(s). + """ + if images is not None: + # Process images with the image processor + # Ensure return_tensors is set to "pt" for PyTorch tensors + image_kwargs = {**kwargs} + if "return_tensors" not in image_kwargs: + image_kwargs["return_tensors"] = "pt" + pixel_values = self.image_processor(images, **image_kwargs) + else: + pixel_values = None + + text_inputs = self.tokenizer(text, **kwargs) if text is not None else None + + if pixel_values is not None and text_inputs is not None: + text_inputs["pixel_values"] = pixel_values["pixel_values"] + return text_inputs + elif pixel_values is not None: + return pixel_values + else: + return text_inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's batch_decode. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's decode. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +AutoProcessor.register("BagelProcessor", BagelProcessor) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 8745e1d9dbbbc..90af573535d3b 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -60,17 +60,17 @@ def __getattr__(name: str): return cached_tokenizer_from_config if name == "init_tokenizer_from_configs": - from vllm.tokenizers import init_tokenizer_from_config + from vllm.tokenizers import cached_tokenizer_from_config warnings.warn( "`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` " - "has been moved to `vllm.tokenizers.init_tokenizer_from_config`. " + "has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. " "The old name will be removed in v0.14.", DeprecationWarning, stacklevel=2, ) - return init_tokenizer_from_config + return cached_tokenizer_from_config raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index a099fde1bdc45..3d4f8449ad3b6 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -38,7 +38,7 @@ class DeepGemmQuantScaleFMT(Enum): return DeepGemmQuantScaleFMT.FLOAT32 return ( DeepGemmQuantScaleFMT.UE8M0 - if current_platform.is_device_capability(100) + if current_platform.is_device_capability_family(100) else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 ) @@ -50,7 +50,7 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability(100) + or current_platform.is_device_capability_family(100) ) return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch @@ -381,22 +381,6 @@ def should_use_deepgemm_for_fp8_linear( ) -def should_use_deepgemm_for_fp8_linear_for_nk( - output_dtype: torch.dtype, - shape0: int, - shape1: int, - supports_deep_gemm: bool | None = None, -): - if supports_deep_gemm is None: - supports_deep_gemm = is_deep_gemm_supported() - return ( - supports_deep_gemm - and output_dtype == torch.bfloat16 - and shape0 % 128 == 0 - and shape1 % 128 == 0 - ) - - __all__ = [ "calc_diff", "DeepGemmQuantScaleFMT", @@ -411,7 +395,6 @@ __all__ = [ "is_deep_gemm_supported", "get_num_sms", "should_use_deepgemm_for_fp8_linear", - "should_use_deepgemm_for_fp8_linear_for_nk", "get_col_major_tma_aligned_tensor", "get_mk_alignment_for_contiguous_layout", ] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 9a66049350cd8..1c2710be3173b 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool: ) +@functools.cache +def has_flashinfer_trtllm_fused_moe() -> bool: + """Return `True` if FlashInfer TRTLLM fused MoE is available.""" + if not has_flashinfer_moe(): + return False + required_functions = [ + ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), + ] + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + 
@functools.cache def has_flashinfer_cutlass_fused_moe() -> bool: """Return `True` if FlashInfer CUTLASS fused MoE is available.""" @@ -264,7 +281,9 @@ def supports_trtllm_attention() -> bool: return False # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - return current_platform.is_device_capability(100) and has_nvidia_artifactory() + return ( + current_platform.is_device_capability_family(100) and has_nvidia_artifactory() + ) def force_use_trtllm_attention() -> bool | None: diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index edcb79fbc9cd7..c97efce312b56 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -194,33 +194,12 @@ def get_kv_cache_torch_dtype( return torch_dtype -def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None: - quant_method = quant_cfg.get("quant_method", "") - if quant_method.startswith("modelopt"): - quantization_inner = quant_cfg.get("quantization", quant_cfg) - # Check if quant config is specified and use kv cache quant algo - kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get( - "kv_cache_quant_algo" - ) - if isinstance(kv_algo, str): - return STR_DTYPE_TO_TORCH_DTYPE[kv_algo.lower()] - return None - - def kv_cache_dtype_str_to_dtype( kv_cache_dtype: str, model_config: ModelConfig ) -> torch.dtype: - # Model config may not be specified for unit tests, default to float16 - dtype = model_config.dtype if model_config else torch.half if kv_cache_dtype == "auto": - hf_cfg = getattr(model_config, "hf_config", None) - if hf_cfg is not None: - quant_cfg = getattr(hf_cfg, "quantization_config", None) - if quant_cfg is not None: - kv_algo_dtype = get_kv_cache_quant_algo_dtype(quant_cfg) - return kv_algo_dtype if kv_algo_dtype is not None else dtype - return dtype - + # Model config may not be specified for unit tests, default to float16 + return model_config.dtype if model_config else torch.half return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f5ad98cf2125c..3445e998d6371 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" +import copy from dataclasses import dataclass from typing import ClassVar @@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH ) + supports_update_block_table: bool = True def __init__( self, @@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ) return attn_metadata + def update_block_table( + self, + metadata: FlashAttentionMetadata, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> FlashAttentionMetadata: + new_metadata = copy.copy(metadata) + new_metadata.block_table = blk_table + new_metadata.slot_mapping = slot_mapping + return new_metadata + def use_cascade_attention(self, *args, **kwargs) -> bool: return use_cascade_attention(*args, **kwargs) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4174b80ee312e..2740a6916fd97 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -564,7 +564,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) 
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() - if self.head_dim == 256 and current_platform.is_device_capability(100): + if self.head_dim == 256 and current_platform.is_device_capability_family(100): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # head size 256 and block size 16 is not supported on blackwell. assert kv_cache_spec.block_size != 16, ( diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 3a2f92d9921c3..ace2cbb0564c8 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -211,7 +211,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] spec_token_masks = torch.repeat_interleave( spec_sequence_masks, query_lens ) - index = torch.argsort(spec_token_masks) + index = torch.argsort(spec_token_masks, stable=True) num_non_spec_tokens = num_prefill_tokens + num_decode_tokens non_spec_token_indx = index[:num_non_spec_tokens] spec_token_indx = index[num_non_spec_tokens:] diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index bf1d8f09ab0ac..f923371283aa0 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import itertools from dataclasses import dataclass @@ -134,6 +135,8 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] ): + supports_update_block_table: bool = True + def __init__( self, kv_cache_spec: AttentionSpec, @@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder( num_computed_tokens_p=num_computed_tokens_p, ) return attn_metadata + + def update_block_table( + self, + metadata: Mamba2AttentionMetadata, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> Mamba2AttentionMetadata: + new_metadata = copy.copy(metadata) + prefix_caching = self.vllm_config.cache_config.enable_prefix_caching + state_indices_t = blk_table if prefix_caching else blk_table[:, 0] + num_reqs = blk_table.shape[0] + + # For CUDA graphs, copy to persistent buffer + if ( + metadata.num_prefills == 0 + and num_reqs <= self.decode_cudagraph_max_bs + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): + persistent_state_indices_t = self.state_indices_tensor[:num_reqs] + persistent_state_indices_t.copy_(state_indices_t, non_blocking=True) + state_indices_t = persistent_state_indices_t + + new_metadata.state_indices_tensor = state_indices_t + return new_metadata diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8265503c28c35..fea482493635f 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -446,7 +446,7 @@ def use_flashinfer_prefill() -> bool: and flashinfer_available and not vllm_config.attention_config.use_cudnn_prefill and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) ) @@ -457,7 +457,7 @@ def use_cudnn_prefill() -> bool: return ( flashinfer_available and vllm_config.attention_config.use_cudnn_prefill - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and has_nvidia_artifactory() ) @@ -470,7 +470,7 @@ def 
use_trtllm_ragged_deepseek_prefill() -> bool: return ( flashinfer_available and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) ) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index f3052fbaf2a65..0818078da0364 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -420,7 +420,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad max_num_sm_parts = int( max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) ) - if current_platform.is_device_capability(100): + if current_platform.is_device_capability_family(100): max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.empty( # TileSchedulerMetaDataSize = 8 @@ -719,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer = indexer.topk_indices_buffer - self.padding = 128 if current_platform.is_device_capability(100) else 64 + self.padding = 128 if current_platform.is_device_capability_family(100) else 64 if kv_cache_dtype == "fp8_ds_mla": # Reserve workspace during initialization diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 00a0a77a1c2f7..589d6ef2f6348 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import ( MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): qo_indptr: torch.Tensor | None = None # The dtype of MLA out tensor attn_out_dtype: torch.dtype = torch.bfloat16 + # The max query output length: int + max_qo_len: int | None = None class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): @@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 - _cudagraph_support: ClassVar[AttentionCGSupport] = ( - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - ) + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM def __init__( self, @@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): max_num_reqs, dtype=torch.int32, device=device ) - self.qo_indptr = torch.arange( - 0, max_num_reqs + 1, dtype=torch.int32, device=device + self.qo_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device ) def _build_decode( @@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): seq_lens_device.cumsum(dim=0, dtype=torch.int32), ] ) + qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_qo_len = qo_len.max().item() if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): num_actual_pages = paged_kv_indices.size(0) @@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): 
self.paged_kv_last_page_len[num_reqs:].fill_(1) paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + self.qo_indptr[: 1 + num_reqs].copy_( + query_start_loc_device, non_blocking=True + ) + self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1] qo_indptr = self.qo_indptr[: 1 + num_reqs] else: @@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): paged_kv_last_page_len=paged_kv_last_page_len, qo_indptr=qo_indptr, dcp_tot_seq_lens=dcp_tot_seq_lens_device, + max_qo_len=max_qo_len, attn_out_dtype=self.decode_attn_out_dtype, ) @@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - # max_seqlen_qo must be 1 except for MTP - # TODO: Find the best value for MTP - max_seqlen_qo = 1 rocm_aiter_ops.mla_decode_fwd( q, kv_buffer, o, self.scale, attn_metadata.decode.qo_indptr, - max_seqlen_qo, + attn_metadata.decode.max_qo_len, attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index da43d87038234..56763f4b52539 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,6 +4,7 @@ import abc import enum import functools from abc import abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field, fields, make_dataclass from typing import ( TYPE_CHECKING, @@ -201,10 +202,11 @@ def _make_metadata_with_slice( ) # NOTE: last token can be outside of the last request if we have CG padding. - # If the "middle" request has tokens in both ubatches, we have to split it. - # If ubatch_slice is the first ubatch then we will be splitting the last - # request. If it's the second microbatch, then we will be splitting the - # first request + # If the request is split across ubatches, we have to adjust the metadata. + # splits_first_request: The first request in this slice is the continuation of + # a request that started in a previous slice. + # splits_last_request: The last request in this slice continues into the + # next slice. splits_first_request = first_tok > start_locs[first_req] splits_last_request = last_tok < start_locs[last_req + 1] - 1 @@ -225,7 +227,10 @@ def _make_metadata_with_slice( seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] if splits_last_request: - tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop + # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate + # the tokens skipped because query_start_loc_cpu might have been modified + # if splits_first_request is True. + tokens_skipped = start_locs[last_req + 1] - token_slice.stop query_start_loc[-1] -= tokens_skipped query_start_loc_cpu[-1] -= tokens_skipped @@ -313,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. reorder_batch_threshold: int | None = None + # Does this backend/builder support updating the block table in existing + # metadata + supports_update_block_table: bool = False @abstractmethod def __init__( @@ -383,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): """ raise NotImplementedError + def update_block_table( + self, + metadata: M, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> M: + """ + Update the block table for the attention metadata. 
+ Faster when there are multiple kv-cache groups that create virtually the + same metadata but just with different block tables. + + Only needs to be implemented if supports_update_block_table is True. + """ + raise NotImplementedError + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> M: @@ -599,7 +622,7 @@ def make_local_attention_virtual_batches( attn_chunk_size: int, common_attn_metadata: CommonAttentionMetadata, block_size: int = 0, -) -> CommonAttentionMetadata: +) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]: query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy() seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy() block_table = common_attn_metadata.block_table_tensor @@ -711,9 +734,12 @@ # tensor first, which recovers perf. batch_indices_torch = torch.from_numpy(batch_indices) block_indices_torch = torch.from_numpy(block_indices) - block_table_local = block_table[batch_indices_torch, block_indices_torch].view( - virtual_batches, -1 - ) + + # Save as a lambda so we can return this for update_block_table + make_block_table = lambda block_table: block_table[ + batch_indices_torch, block_indices_torch + ].view(virtual_batches, -1) + block_table_local = make_block_table(block_table) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) @@ -732,7 +758,7 @@ causal=True, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), - ) + ), make_block_table def make_kv_sharing_fast_prefill_common_attn_metadata( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a9ce6e63cc775..754e0b9d08316 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -187,6 +187,12 @@ class Scheduler(SchedulerInterface): if self.is_encoder_decoder else EncoderCacheManager(cache_size=encoder_cache_size) ) + # For encoder-decoder models, allocate the maximum number of tokens for + # cross-attention blocks, since Whisper's input is always padded to the maximum length. + # TODO (NickLucche): Generalize to models with variable-length encoder inputs. + self._num_encoder_max_input_tokens = ( + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config) + ) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -568,17 +574,11 @@ class Scheduler(SchedulerInterface): 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens ) - # Determine if we need to allocate cross-attention blocks. - if self.is_encoder_decoder and request.has_encoder_inputs: - # TODO(russellb): For Whisper, we know that the input is - # always padded to the maximum length. If we support other - # encoder-decoder models, this will need to be updated if we - # want to only allocate what is needed.
- num_encoder_tokens = ( - self.scheduler_config.max_num_encoder_input_tokens - ) - else: - num_encoder_tokens = 0 + num_encoder_tokens = ( + self._num_encoder_max_input_tokens + if self.is_encoder_decoder and request.has_encoder_inputs + else 0 + ) new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -1117,6 +1117,7 @@ class Scheduler(SchedulerInterface): stopped = False new_logprobs = None new_token_ids = generated_token_ids + pooler_output = pooler_outputs[req_index] if pooler_outputs else None kv_transfer_params = None status_before_stop = request.status @@ -1125,12 +1126,10 @@ class Scheduler(SchedulerInterface): new_token_ids, stopped = self._update_request_with_output( request, new_token_ids ) - - # Stop checking for pooler models. - pooler_output = None - if pooler_outputs: - pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, pooler_output) + elif request.pooling_params and pooler_output is not None: + # Pooling stops as soon as there is output. + request.status = RequestStatus.FINISHED_STOPPED + stopped = True if stopped: kv_transfer_params = self._free_request(request) diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 82166dc978396..6319731883225 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib -import torch - from vllm.v1.request import Request, RequestStatus @@ -39,14 +37,8 @@ def remove_all(lst: list, items_to_remove: set) -> list: return [item for item in lst if item not in items_to_remove] -def check_stop( - request: Request, max_model_len: int, pooler_output: torch.Tensor | None = None -) -> bool: - if request.pooling_params: - if pooler_output is not None: - request.status = RequestStatus.FINISHED_STOPPED - return True - return False +def check_stop(request: Request, max_model_len: int) -> bool: + assert not request.pooling_params sampling_params = request.sampling_params assert sampling_params is not None diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 8eff61563ccea..a6ee241c41151 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -26,7 +26,7 @@ from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.usage.usage_lib import UsageContext @@ -111,7 +111,7 @@ class AsyncLLM(EngineClient): if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = init_tokenizer_from_config(self.model_config) + tokenizer = cached_tokenizer_from_config(self.model_config) self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index a3c18464d3f52..65e0c845b0afa 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -19,7 +19,8 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from 
vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest from vllm.v1.metrics.stats import MultiModalCacheStats @@ -188,29 +189,39 @@ class InputProcessor: def _validate_single_prompt(single_prompt: dict | str) -> None: if not isinstance(single_prompt, dict): return + mm_data = single_prompt.get("multi_modal_data") mm_uuids = single_prompt.get("multi_modal_uuids") if not mm_data or not mm_uuids: return + import torch + + def _get_len(items: object): + if isinstance(items, dict): # Embedding inputs + return _get_len(next(iter(items.values()))) if items else 1 + + if isinstance(items, list): + return len(items) + if isinstance(items, torch.Tensor): + # To keep backwards compatibility for single item embedding input + return 1 if getattr(items, "_is_single_item", False) else len(items) + + return 1 + for modality, items in mm_data.items(): if modality in mm_uuids: - data_len = len(items) if isinstance(items, list) else 1 - uuid_len = ( - len(mm_uuids[modality]) - if isinstance(mm_uuids[modality], list) - else 1 - ) + data_len = _get_len(items) + uuid_len = _get_len(mm_uuids[modality]) if uuid_len != data_len: raise ValueError( - f"multi_modal_uuids for modality '{modality}' " + f"multi_modal_uuids for modality {modality!r} " "must have same length as data: got " - f"{uuid_len} uuids vs " - f"{data_len} items." + f"{uuid_len} uuids vs {data_len} items." ) else: raise ValueError( - f"multi_modal_uuids for modality '{modality}' must " + f"multi_modal_uuids for modality {modality!r} must " "be provided if multi_modal_data is provided." ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4422eced82fea..1011317b706d3 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -23,7 +23,7 @@ from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest @@ -86,7 +86,7 @@ class LLMEngine: if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = init_tokenizer_from_config(self.model_config) + tokenizer = cached_tokenizer_from_config(self.model_config) self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 9be3f4da7352d..8f7d8a71f1a2e 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,6 +8,7 @@ from typing import Any, cast import torch +from vllm.lora.request import LoRARequest from vllm.outputs import ( CompletionOutput, PoolingOutput, @@ -93,7 +94,7 @@ class RequestState: request_id: str, parent_req: ParentRequest | None, request_index: int, - lora_name: str | None, + lora_request: LoRARequest | None, output_kind: RequestOutputKind, prompt: str | None, prompt_token_ids: list[int] | None, @@ -112,7 +113,8 @@ class RequestState: self.request_id = request_id self.parent_req = parent_req self.request_index = request_index - self.lora_name = lora_name + 
self.lora_request = lora_request + self.lora_name = lora_request.lora_name if lora_request is not None else None self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids @@ -178,9 +180,7 @@ class RequestState: request_id=request.request_id, parent_req=parent_req, request_index=request_index, - lora_name=( - request.lora_request.name if request.lora_request is not None else None - ), + lora_request=request.lora_request, output_kind=output_kind, prompt=prompt, prompt_token_ids=request.prompt_token_ids, @@ -289,6 +289,7 @@ class RequestState: return RequestOutput( request_id=request_id, + lora_request=self.lora_request, prompt=self.prompt, prompt_token_ids=prompt_token_ids, prompt_logprobs=prompt_logprobs, diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index b42d026a3e15b..649875fe8b7c1 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -124,9 +124,7 @@ class MultiprocExecutor(Executor): # Set multiprocessing envs set_multiprocessing_worker_envs() - # Multiprocessing-based executor does not support multi-node setting. - # Since it only works for single node, we can use the loopback address - # get_loopback_ip() for communication. + # use the loopback address get_loopback_ip() for communication. distributed_init_method = get_distributed_init_method( get_loopback_ip(), get_open_port() ) @@ -708,7 +706,7 @@ class WorkerProc: death_pipe.recv() except EOFError: # Parent process has exited, terminate this worker - logger.info("Parent process exited, terminating worker") + logger.info_once("Parent process exited, terminating worker") # Send signal to self to trigger clean shutdown shutdown_event.set() except Exception as e: diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py index 2f2e85c0ff332..e1cf7b14a785c 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu.py @@ -13,7 +13,7 @@ from vllm.v1.kv_offload.backends.cpu import CPUBackend from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.spec import OffloadingSpec -from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler +from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers from vllm.v1.kv_offload.worker.worker import OffloadingHandler @@ -32,7 +32,7 @@ class CPUOffloadingSpec(OffloadingSpec): self._manager: OffloadingManager | None = None # worker-side - self._handler: OffloadingHandler | None = None + self._handlers: CpuGpuOffloadingHandlers | None = None self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru") @@ -67,13 +67,13 @@ class CPUOffloadingSpec(OffloadingSpec): kv_caches: dict[str, torch.Tensor], attn_backends: dict[str, type[AttentionBackend]], ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: - if not self._handler: + if not self._handlers: if not current_platform.is_cuda_alike(): raise Exception( "CPU Offloading is currently only supported on CUDA-alike GPUs" ) - self._handler = CpuGpuOffloadingHandler( + self._handlers = CpuGpuOffloadingHandlers( attn_backends=attn_backends, gpu_block_size=self.gpu_block_size, cpu_block_size=self.offloaded_block_size, @@ -81,6 +81,6 @@ class CPUOffloadingSpec(OffloadingSpec): gpu_caches=kv_caches, ) - assert self._handler is not None - yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler - yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler + 
assert self._handlers is not None + yield GPULoadStoreSpec, CPULoadStoreSpec, self._handlers.gpu_to_cpu_handler + yield CPULoadStoreSpec, GPULoadStoreSpec, self._handlers.cpu_to_gpu_handler diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index 461458c1f6ce8..42ae4f1413ad0 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import deque import numpy as np import torch @@ -8,7 +9,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec +from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec from vllm.v1.kv_offload.worker.worker import ( OffloadingHandler, TransferResult, @@ -51,7 +52,123 @@ def expand_block_ids( output_idx = output_end_idx -class CpuGpuOffloadingHandler(OffloadingHandler): +class SingleDirectionOffloadingHandler(OffloadingHandler): + """ + SingleDirectionOffloadingHandler handles transfers for a single direction, + either CPU->GPU or GPU->CPU. + Transfers are guaranteed to be executed in order of their submission. + Each transfer uses a unique CUDA stream, and its stream will start + executing only after the streams of previous transfers have finished. + """ + + def __init__( + self, + src_tensors: list[torch.Tensor], + dst_tensors: list[torch.Tensor], + kv_dim_before_num_blocks: list[bool], + src_block_size_factor: int, + dst_block_size_factor: int, + priority: int, + ): + """ + Initialize a SingleDirectionOffloadingHandler. + + Args: + src_tensors: list of KV cache tensors to copy from. + dst_tensors: list of KV cache tensors to copy to. + Order should match src_tensors. + kv_dim_before_num_blocks: list of bools, indicating + whether the respective KV cache tensor has a KV + dimension before its num_blocks dimension. + e.g. (2, num_blocks, ...) + src_block_size_factor: The number of kernel blocks + per KV block in a source tensor. + dst_block_size_factor: The number of kernel blocks + per KV block in a destination tensor. + priority: The priority of the backing CUDA streams. + Lower numbers indicate higher priority. 
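+
+ Example (illustrative; the spec variables are placeholders):
+
+ handler.transfer_async(1, (src_spec_a, dst_spec_a))
+ handler.transfer_async(2, (src_spec_b, dst_spec_b))
+
+ The second job's stream waits on the first job's completion event, so its
+ copies start only after the first transfer finishes, and get_finished()
+ reports the jobs in submission order.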
+ """ + assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks) + + self.src_tensors: list[torch.Tensor] = src_tensors + self.dst_tensors: list[torch.Tensor] = dst_tensors + self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks + self.src_block_size_factor: int = src_block_size_factor + self.dst_block_size_factor: int = dst_block_size_factor + self.priority = priority + + # queue of transfers (job_id, stream, event) + self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque() + # list of CUDA streams available for re-use + self._stream_pool: list[torch.cuda.Stream] = [] + # list of CUDA events available for re-use + self._event_pool: list[torch.Event] = [] + + def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool: + src_spec, dst_spec = transfer_spec + assert isinstance(src_spec, BlockIDsLoadStoreSpec) + assert isinstance(dst_spec, BlockIDsLoadStoreSpec) + + src_blocks = src_spec.block_ids + dst_blocks = dst_spec.block_ids + assert src_blocks.ndim == 1 + assert dst_blocks.ndim == 1 + + src_sub_block_count = src_blocks.size * self.src_block_size_factor + dst_sub_block_count = dst_blocks.size * self.dst_block_size_factor + src_sub_blocks_to_skip = -dst_blocks.size % self.src_block_size_factor + + assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip + + src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64) + expand_block_ids( + src_blocks, + self.src_block_size_factor, + src_to_dst[:, 0], + skip_count=src_sub_blocks_to_skip, + ) + expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1]) + src_to_dst_tensor = torch.from_numpy(src_to_dst) + + stream = ( + self._stream_pool.pop() + if self._stream_pool + else torch.cuda.Stream(priority=self.priority) + ) + event = self._event_pool.pop() if self._event_pool else torch.Event() + if self._transfers: + _, _, last_event = self._transfers[-1] + # assure job will start only after the previous one completes + stream.wait_event(last_event) + with torch.cuda.stream(stream): + for src_tensor, dst_tensor, kv_dim in zip( + self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks + ): + if kv_dim: + src_key_cache, src_value_cache = src_tensor + dst_key_cache, dst_value_cache = dst_tensor + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) + else: + ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) + event.record(stream) + + self._transfers.append((job_id, stream, event)) + + # success + return True + + def get_finished(self) -> list[TransferResult]: + results: list[TransferResult] = [] + while self._transfers and self._transfers[0][2].query(): + job_id, stream, event = self._transfers.popleft() + results.append((job_id, True)) + self._stream_pool.append(stream) + self._event_pool.append(event) + return results + + +class CpuGpuOffloadingHandlers: def __init__( self, gpu_block_size: int, @@ -60,27 +177,20 @@ class CpuGpuOffloadingHandler(OffloadingHandler): gpu_caches: dict[str, torch.Tensor], attn_backends: dict[str, type[AttentionBackend]], ): + assert gpu_caches assert cpu_block_size % gpu_block_size == 0 - self.block_size_factor = cpu_block_size // gpu_block_size - - # cuda streams for gpu->cpu and cpu->gpu - self.d2h_stream = torch.cuda.Stream() - self.h2d_stream = torch.cuda.Stream() - - # job_id -> transfer cuda event - self.transfer_events: dict[int, torch.Event] = {} - # list of cuda events available for re-use - 
self.events_pool: list[torch.Event] = [] + block_size_factor = cpu_block_size // gpu_block_size pin_memory = is_pin_memory_available() # allocate cpu tensors logger.info("Allocating %d CPU tensors...", len(gpu_caches)) - self.gpu_tensors: list[torch.Tensor] = [] - self.cpu_tensors: list[torch.Tensor] = [] - self.kv_dim_before_num_blocks: list[bool] = [] + gpu_tensors: list[torch.Tensor] = [] + cpu_tensors: list[torch.Tensor] = [] + kv_dim_before_num_blocks: list[bool] = [] + kernel_block_size: int | None = None for layer_name, gpu_tensor in gpu_caches.items(): - self.gpu_tensors.append(gpu_tensor) + gpu_tensors.append(gpu_tensor) gpu_shape = gpu_tensor.shape attn_backend = attn_backends[layer_name] @@ -88,16 +198,21 @@ class CpuGpuOffloadingHandler(OffloadingHandler): num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 ) + has_layers_dim = False if len(gpu_shape) != len(test_shape): # cross-layers tensor # shape is (num_blocks, ...) assert len(gpu_shape) == len(test_shape) + 1 num_blocks_idx = 0 - self.kv_dim_before_num_blocks.append(False) + has_layers_dim = True + kv_dim_before_num_blocks.append(False) + + # prepend a dummy num_layers=80 to test_shape + test_shape = (80,) + test_shape elif test_shape[0] == 1234: # shape is (num_blocks, ...) num_blocks_idx = 0 - self.kv_dim_before_num_blocks.append(False) + kv_dim_before_num_blocks.append(False) else: # shape should be (2, num_blocks, ...) assert test_shape[0] == 2 @@ -105,13 +220,32 @@ class CpuGpuOffloadingHandler(OffloadingHandler): assert gpu_shape[0] == 2 num_blocks_idx = 1 - self.kv_dim_before_num_blocks.append(True) + kv_dim_before_num_blocks.append(True) + + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=has_layers_dim + ) + assert len(kv_cache_stride_order) == len(gpu_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(gpu_shape))) + + # permute test_shape according to stride_order + test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) + + # find block_size (16) dimension index + block_size_idx = test_shape.index(16) + if kernel_block_size is not None: + assert kernel_block_size == gpu_shape[block_size_idx] + else: + kernel_block_size = gpu_shape[block_size_idx] + assert gpu_block_size % kernel_block_size == 0 cpu_shape = list(gpu_shape) - cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor + cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor logger.debug("Allocating CPU tensor of shape %r", cpu_shape) - self.cpu_tensors.append( + cpu_tensors.append( torch.zeros( cpu_shape, dtype=gpu_tensor.dtype, @@ -120,72 +254,27 @@ class CpuGpuOffloadingHandler(OffloadingHandler): ) ) - def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: - src_spec, dst_spec = spec - if isinstance(src_spec, CPULoadStoreSpec): - assert isinstance(dst_spec, GPULoadStoreSpec) - stream = self.h2d_stream - src_tensors = self.cpu_tensors - dst_tensors = self.gpu_tensors - src_block_size_factor = self.block_size_factor - dst_block_size_factor = 1 - else: - assert isinstance(src_spec, GPULoadStoreSpec) - assert isinstance(dst_spec, CPULoadStoreSpec) - stream = self.d2h_stream - src_tensors = self.gpu_tensors - dst_tensors = self.cpu_tensors - src_block_size_factor = 1 - dst_block_size_factor = self.block_size_factor + assert kernel_block_size is not None + gpu_block_size_factor = gpu_block_size // kernel_block_size + cpu_block_size_factor = cpu_block_size // kernel_block_size - src_blocks = 
src_spec.block_ids - dst_blocks = dst_spec.block_ids - assert src_blocks.ndim == 1 - assert dst_blocks.ndim == 1 + # TODO (orozery): adapt swap_blocks to support gpu_block_size_factor + assert gpu_block_size_factor == 1 - src_sub_block_count = src_blocks.size * src_block_size_factor - dst_sub_block_count = dst_blocks.size * dst_block_size_factor - src_sub_blocks_to_skip = -dst_blocks.size % src_block_size_factor - - assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip - - src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64) - expand_block_ids( - src_blocks, - src_block_size_factor, - src_to_dst[:, 0], - skip_count=src_sub_blocks_to_skip, + self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler( + src_tensors=gpu_tensors, + dst_tensors=cpu_tensors, + kv_dim_before_num_blocks=kv_dim_before_num_blocks, + src_block_size_factor=gpu_block_size_factor, + dst_block_size_factor=cpu_block_size_factor, + priority=1, ) - expand_block_ids(dst_blocks, dst_block_size_factor, src_to_dst[:, 1]) - src_to_dst_tensor = torch.from_numpy(src_to_dst) - event = self.events_pool.pop() if self.events_pool else torch.Event() - with torch.cuda.stream(stream): - for src_tensor, dst_tensor, kv_dim in zip( - src_tensors, dst_tensors, self.kv_dim_before_num_blocks - ): - if kv_dim: - src_key_cache = src_tensor[0] - dst_key_cache = dst_tensor[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) - src_value_cache = src_tensor[1] - dst_value_cache = dst_tensor[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) - else: - ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) - event.record(stream) - - self.transfer_events[job_id] = event - - # success - return True - - def get_finished(self) -> list[TransferResult]: - results: list[TransferResult] = [] - for job_id, event in self.transfer_events.items(): - if event.query(): - results.append((job_id, True)) - self.events_pool.append(event) - for job_id, _ in results: - del self.transfer_events[job_id] - return results + self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler( + src_tensors=cpu_tensors, + dst_tensors=gpu_tensors, + kv_dim_before_num_blocks=kv_dim_before_num_blocks, + src_block_size_factor=cpu_block_size_factor, + dst_block_size_factor=gpu_block_size_factor, + priority=-1, + ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 4dd478804049b..79ee4161e9dfa 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.tokenizers import init_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import ( @@ -71,7 +71,7 @@ class StructuredOutputManager: # of CPUs. 
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.tokenizer = init_tokenizer_from_config( + self.tokenizer = cached_tokenizer_from_config( model_config=self.vllm_config.model_config ) reasoning_parser = ( diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 826ee08caa4e2..9dd506880389a 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -10,7 +10,8 @@ import torch import vllm.envs from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer +from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, @@ -267,13 +268,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: # Unsupported keywords for objects if obj.get("type") == "object" and any( - key in obj - for key in ( - "minProperties", - "maxProperties", - "propertyNames", - "patternProperties", - ) + key in obj for key in ("patternProperties", "propertyNames") ): return True diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index ae42b33f80f88..cb5ad99cfbdf7 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -21,8 +21,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput if TYPE_CHECKING: import outlines_core as oc import transformers.file_utils as file_utils - import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2 import xgrammar as xgr + from transformers.convert_slow_tokenizer import bytes_to_unicode from vllm.tokenizers import TokenizerLike from vllm.v1.worker.gpu_input_batch import InputBatch @@ -30,10 +30,8 @@ else: xgr = LazyLoader("xgr", globals(), "xgrammar") oc = LazyLoader("oc", globals(), "outlines_core") file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils") - tokenization_gpt2 = LazyLoader( - "tokenization_gpt2", - globals(), - "transformers.models.gpt2.tokenization_gpt2", + bytes_to_unicode = LazyLoader( + "bytes_to_unicode", globals(), "transformers.convert_slow_tokenizer" ) TokenizerLike = object @@ -204,7 +202,7 @@ def _reduced_vocabulary( A Dict of token string -> equivalent token ids """ - unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()} + unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()} def convert_token_to_string(token: str) -> str: string = tokenizer.convert_tokens_to_string([token]) diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index 1b9646e1980a8..82de0cba9194b 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -11,7 +11,7 @@ from vllm.distributed.parallel_state import get_dp_group from vllm.logger import init_logger from vllm.v1.worker.ubatch_utils import ( check_ubatch_thresholds, - is_second_ubatch_empty, + is_last_ubatch_empty, ) logger = init_logger(__name__) @@ -56,7 +56,7 @@ def _run_ar( return tensor -def _post_process_ubatch(tensor: torch.Tensor) -> bool: +def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool: orig_num_tokens_tensor = tensor[0, :] padded_num_tokens_tensor = tensor[1, :] @@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool: # 
there are no "empty" second ubatches
 orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
 padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
- if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+ if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
 logger.debug(
 "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
 )
@@ -146,7 +146,7 @@ def _synchronize_dp_ranks(
 assert should_attempt_dp_padding == should_dp_pad
 # Check conditions for microbatching
- should_ubatch = _post_process_ubatch(tensor)
+ should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
 if should_ubatch and not should_dp_pad:
 logger.debug_once(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 978224faae65e..179f713c4d86a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1630,6 +1630,15 @@ class GPUModelRunner(
 logits_indices
 )
+ # Cache attention metadata builds across hybrid KV-cache groups.
+ # The only thing that changes between hybrid KV-cache groups that share the
+ # same metadata builder and KVCacheSpec is the block table, so we
+ # can cache the attention metadata builds and just update the block table using
+ # `builder.update_block_table` if the builder supports it.
+ cached_attn_metadata: dict[
+ tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
+ ] = {}
+
 def _build_attn_group_metadata(
 kv_cache_gid: int,
 attn_gid: int,
@@ -1637,13 +1646,15 @@
 ubid: int | None = None,
 ) -> None:
 attn_group = self.attn_groups[kv_cache_gid][attn_gid]
+ builder = attn_group.get_metadata_builder(ubid or 0)
+ cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder))
+
 cascade_attn_prefix_len = (
 cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
 if cascade_attn_prefix_lens
 else 0
 )
- builder = attn_group.get_metadata_builder(ubid or 0)
 extra_attn_metadata_args = {}
 if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
 assert ubid is None, "UBatching not supported with GDN yet"
@@ -1658,12 +1669,23 @@
 attn_metadata_i = builder.build_for_cudagraph_capture(
 common_attn_metadata
 )
+ elif (
+ cache_key in cached_attn_metadata
+ and builder.supports_update_block_table
+ ):
+ attn_metadata_i = builder.update_block_table(
+ cached_attn_metadata[cache_key],
+ common_attn_metadata.block_table_tensor,
+ common_attn_metadata.slot_mapping,
+ )
 else:
 attn_metadata_i = builder.build(
 common_prefix_len=cascade_attn_prefix_len,
 common_attn_metadata=common_attn_metadata,
 **extra_attn_metadata_args,
 )
+ if builder.supports_update_block_table:
+ cached_attn_metadata[cache_key] = attn_metadata_i
 if ubid is None:
 assert isinstance(attn_metadata, dict)
@@ -2987,7 +3009,7 @@ class GPUModelRunner(
 cascade_attn_prefix_lens = None
 # Disable cascade attention when using microbatching (DBO)
- if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
+ if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
 # Pre-compute cascade attention prefix lengths
 cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
 num_scheduled_tokens_np,
@@ -3028,6 +3050,13 @@ class GPUModelRunner(
 num_scheduled_tokens_np,
 num_tokens_padded,
 num_reqs_padded,
+ self.parallel_config.num_ubatches,
+ )
+
+ logger.debug(
+ "ubatch_slices: %s, ubatch_slices_padded: %s",
+ ubatch_slices,
+ ubatch_slices_padded,
 )
 pad_attn = cudagraph_mode ==
CUDAGraphMode.FULL @@ -3710,11 +3739,14 @@ class GPUModelRunner( # wrap the model with full cudagraph wrapper if needed. cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: + if ( + cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.use_ubatching + ): self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) - elif self.parallel_config.enable_dbo: + elif self.parallel_config.use_ubatching: if cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( self.model, self.vllm_config, CUDAGraphMode.FULL, self.device @@ -4095,7 +4127,16 @@ class GPUModelRunner( batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ) ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( - should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, ) attn_metadata: PerLayerAttnMetadata | None = None @@ -4644,7 +4685,7 @@ class GPUModelRunner( # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph allow_microbatching = ( - self.parallel_config.enable_dbo + self.parallel_config.use_ubatching and cudagraph_runtime_mode == CUDAGraphMode.FULL and uniform_decode and check_ubatch_thresholds( @@ -4779,8 +4820,8 @@ class GPUModelRunner( if kv_cache_group_id < len(kernel_block_sizes) else None, num_metadata_builders=1 - if not self.parallel_config.enable_dbo - else 2, + if not self.parallel_config.use_ubatching + else self.parallel_config.num_ubatches, ) # Calculate reorder batch threshold (if needed) # Note (tdoublep): do this *after* constructing builders, diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 2ce2b64512560..af09129e67b1e 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -103,8 +103,10 @@ class UBatchWrapper: self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.comm_stream = torch.cuda.Stream(device=device) - # Two ubatch threads plus the main thread - self.ready_barrier = threading.Barrier(3) + # Ubatch threads plus the main thread + self.ready_barrier = threading.Barrier( + self.vllm_config.parallel_config.num_ubatches + 1 + ) self.cudagraphs: dict[int, CUDAGraphMetaData] = {} @@ -309,7 +311,7 @@ class UBatchWrapper: create_forward_context( attn_metadata[i] if attn_metadata is not None else None, self.vllm_config, - dp_metadata=dp_metadata, + dp_metadata=dp_metadata[i], batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode, ) @@ -417,18 +419,19 @@ class UBatchWrapper: # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None - num_tokens_per_ubatch = ( - ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start - ) - dp_size = self.vllm_config.parallel_config.data_parallel_size - ubatch_num_tokens_across_dp = torch.tensor( - [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32 - ) - ubatch_dp_metadata = DPMetadata.make( - self.vllm_config.parallel_config, - num_tokens_per_ubatch, - ubatch_num_tokens_across_dp, - ) + ubatch_dp_metadata = [] + for ubatch_slice in ubatch_slices: + dp_size = 
self.vllm_config.parallel_config.data_parallel_size + ubatch_num_tokens_across_dp = torch.tensor( + [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32 + ) + ubatch_dp_metadata.append( + DPMetadata.make( + self.vllm_config.parallel_config, + ubatch_slice.num_tokens, + ubatch_num_tokens_across_dp, + ) + ) if ( num_tokens not in self.cudagraphs @@ -464,7 +467,7 @@ class UBatchWrapper: intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, compute_stream=compute_stream, - dp_metadata=dp_metadata, + dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 21a8564f83c40..1e13650cd083e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -931,10 +931,11 @@ def init_worker_distributed_environment( backend: str = "nccl", ) -> None: """Initialize the distributed environment.""" + attention_config = vllm_config.attention_config parallel_config = vllm_config.parallel_config from vllm.model_executor.layers.batch_invariant import init_batch_invariance - init_batch_invariance() + init_batch_invariance(attention_config.backend) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_method = distributed_init_method or "env://" diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 44788476fc9c5..f6889173578d6 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -27,14 +27,16 @@ class UBatchSlice: UBatchSlices: TypeAlias = list[UBatchSlice] -def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool: - return (padded_num_tokens // 2) >= orig_num_tokens +def is_last_ubatch_empty( + orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int +) -> bool: + return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens def check_ubatch_thresholds( config: ParallelConfig, num_tokens: int, uniform_decode: bool ) -> bool: - if not config.enable_dbo: + if not config.use_ubatching: return False if uniform_decode: return num_tokens >= config.dbo_decode_token_threshold @@ -42,21 +44,17 @@ def check_ubatch_thresholds( return num_tokens >= config.dbo_prefill_token_threshold -# This just pads the second ubatch slice out to the total number of tokens +# This pads the last ubatch slice out to the total number of tokens # (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding. 
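+# Worked example (illustrative numbers): with num_ubatches=3, 90 scheduled
+# tokens and num_tokens_padded=96, maybe_create_ubatch_slices splits tokens at
+# [32, 64], giving token slices [0, 32), [32, 64), [64, 90); this helper then
+# widens only the last slice to [64, 96) (and its request slice stop to
+# num_reqs_padded).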
def _pad_out_ubatch_slices( ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int ) -> UBatchSlices: - # TODO(lucas): handle empty second ubatch - padded_second_request_slice = slice( - ubatch_slices[1].request_slice.start, num_reqs_padded - ) - padded_second_token_slice = slice( - ubatch_slices[1].token_slice.start, num_total_tokens - ) - return [ - ubatch_slices[0], - UBatchSlice(padded_second_request_slice, padded_second_token_slice), + last_slice = ubatch_slices[-1] + padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded) + padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens) + + return ubatch_slices[:-1] + [ + UBatchSlice(padded_last_request_slice, padded_last_token_slice) ] @@ -65,40 +63,45 @@ def maybe_create_ubatch_slices( num_scheduled_tokens: np.ndarray, num_tokens_padded: int, num_reqs_padded: int, - split_point: int | None = None, + num_ubatches: int, + split_point: list[int] | int | None = None, ) -> tuple[UBatchSlices | None, UBatchSlices | None]: if not should_ubatch: return None, None if split_point is None: - split_point = int(num_tokens_padded) // 2 + split_point = int(num_tokens_padded) // num_ubatches + + token_split_points = [split_point * i for i in range(1, num_ubatches)] # TODO(lucas): Refactor the gpu_model_runner.py so we can pass # in cu_num_tokens directly (i.e. query_start_loc) cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) - first_ubatch_token_slice = slice(0, split_point) - second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) + ubatch_slices = [] + start_token = 0 - # Determine request slices using exclusive stop semantics - # First ubatch includes requests whose tokens overlap [0, split_point) - first_ubatch_req_stop = int( - np.searchsorted(cu_num_tokens, split_point, side="left") - ) - first_ubatch_req_slice = slice(0, first_ubatch_req_stop) + # Add the end point to the split points to make iteration easier + all_points = token_split_points + [cu_num_tokens[-1]] - # Second ubatch starts at the request that contains the split_point - # or the request starting exactly at split_point (if on boundary) - second_ubatch_req_start = int( - np.searchsorted(cu_num_tokens, split_point, side="right") - 1 - ) - second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) + for end_token in all_points: + token_slice = slice(start_token, end_token) - ubatch_slices = [ - UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), - UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), - ] + # Determine request slices using exclusive stop semantics + # Ubatch includes requests whose tokens overlap [start_token, end_token) + + # Start at the request that contains the start_token + # or the request starting exactly at start_token (if on boundary) + req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1) + + # Stop at the request that starts at or after end_token + req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left")) + + req_slice = slice(req_start, req_stop) + ubatch_slices.append(UBatchSlice(req_slice, token_slice)) + + start_token = end_token ubatch_slices_padded = _pad_out_ubatch_slices( ubatch_slices, num_tokens_padded, num_reqs_padded diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index be8326e2fdbc1..e7a947f2ea8ca 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ 
-7,10 +7,15 @@ import torch from vllm import forward_context from vllm.forward_context import ForwardContext +from vllm.logger import init_logger from vllm.utils.torch_utils import current_stream +logger = init_logger(__name__) + _THREAD_ID_TO_CONTEXT: dict = {} -_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] +# Here we hardcode the number of microbatches to 2 for default. +_NUM_UBATCHES: int = 2 +_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [] class UBatchContext: @@ -48,6 +53,7 @@ class UBatchContext: global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id _CURRENT_CONTEXTS[self.id] = self + # _NUM_UBATCHES is set in make_ubatch_contexts self.ready_barrier.wait() self.cpu_wait_event.wait() @@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function( def dbo_register_recv_hook(recv_hook): if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] - next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES] next_ctx.recv_hook = recv_hook @@ -202,7 +208,14 @@ def make_ubatch_contexts( ready_barrier: threading.Barrier, schedule: str = "default", ) -> list[UBatchContext]: - assert num_micro_batches == 2, "only been tested with 2 micro-batches" + global _NUM_UBATCHES, _CURRENT_CONTEXTS + assert num_micro_batches > 1, "num_micro_batches must be greater than 1" + + _NUM_UBATCHES = num_micro_batches + # Ensure the global context list is large enough + if len(_CURRENT_CONTEXTS) < num_micro_batches: + _CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS))) + """ Create a context manager for micro-batching synchronization. """ @@ -210,8 +223,6 @@ def make_ubatch_contexts( gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)] gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)] - assert len(forward_contexts) == 2 - ctxs = [] for i in range(num_micro_batches): ctx = UBatchContext( diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py index a16dde1f67800..bbbd7705d54e4 100644 --- a/vllm/v1/worker/workspace.py +++ b/vllm/v1/worker/workspace.py @@ -145,12 +145,20 @@ class WorkspaceManager: for ubatch_id in range(self._num_ubatches): current_workspace = self._current_workspaces[ubatch_id] - if current_workspace is None: + if ( + current_workspace is None + or self._workspace_size_bytes(current_workspace) < required_bytes + ): + # Delete old tensor before allocating new one to avoid + # memory spike from resize_(). resize_() allocates new + # memory before freeing old, which can cause OOM. + # Must clear the list reference first since local var + # is just a copy of the reference. + self._current_workspaces[ubatch_id] = None + del current_workspace self._current_workspaces[ubatch_id] = torch.empty( (required_bytes,), dtype=torch.uint8, device=self._device ) - elif self._workspace_size_bytes(current_workspace) < required_bytes: - current_workspace.resize_(required_bytes) if envs.VLLM_DEBUG_WORKSPACE: logger.info(