Merge remote-tracking branch 'origin/main' into refactor-fp8-linear

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-12-11 04:51:04 +00:00
commit 52e2a31a95
131 changed files with 2893 additions and 1627 deletions

View File

@ -398,7 +398,8 @@ steps:
timeout_in_minutes: 25 timeout_in_minutes: 25
gpu: h100 gpu: h100
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/v1/attention
- vllm/model_executor/layers
- tests/v1/determinism/ - tests/v1/determinism/
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
@ -440,23 +441,29 @@ steps:
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
source_file_dependencies: source_file_dependencies:
- vllm/entrypoints - vllm/entrypoints
- vllm/multimodal
- examples/ - examples/
commands: commands:
- pip install tensorizer # for tensorizer test - pip install tensorizer # for tensorizer test
# for basic
- python3 offline_inference/basic/chat.py
- python3 offline_inference/basic/generate.py --model facebook/opt-125m - python3 offline_inference/basic/generate.py --model facebook/opt-125m
- python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10
- python3 offline_inference/basic/chat.py - python3 offline_inference/basic/classify.py
- python3 offline_inference/prefix_caching.py - python3 offline_inference/basic/embed.py
- python3 offline_inference/llm_engine_example.py - python3 offline_inference/basic/score.py
# for multi-modal models
- python3 offline_inference/audio_language.py --seed 0 - python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0 - python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0
- python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py # for pooling models
- python3 offline_inference/basic/embed.py - python3 pooling/pooling/vision_language_pooling.py --seed 0
- python3 offline_inference/basic/score.py # for features demo
- python3 offline_inference/prefix_caching.py
- python3 offline_inference/llm_engine_example.py
- python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
# https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU
- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536
@ -718,6 +725,18 @@ steps:
- uv pip install --system conch-triton-kernels - uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
- label: LM Eval Small Models # 53min
timeout_in_minutes: 75
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
# grade: Blocking
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
autorun_on_main: true
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- label: OpenAI API correctness # 10min - label: OpenAI API correctness # 10min
timeout_in_minutes: 15 timeout_in_minutes: 15
mirror_hardwares: [amdexperimental, amdproduction] mirror_hardwares: [amdexperimental, amdproduction]
@ -727,7 +746,7 @@ steps:
- csrc/ - csrc/
- vllm/entrypoints/openai/ - vllm/entrypoints/openai/
- vllm/model_executor/models/whisper.py - vllm/model_executor/models/whisper.py
commands: # LMEval commands: # LMEval+Transcription WER check
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442 # Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
- pytest -s entrypoints/openai/correctness/ - pytest -s entrypoints/openai/correctness/
@ -963,6 +982,19 @@ steps:
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min
timeout_in_minutes: 180
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- vllm/multimodal/
- vllm/inputs/
- vllm/v1/core/
commands:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
- label: Multi-Modal Models Test (Extended) 1 # 60min - label: Multi-Modal Models Test (Extended) 1 # 60min
timeout_in_minutes: 120 timeout_in_minutes: 120
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
@ -1098,7 +1130,6 @@ steps:
- vllm/model_executor/layers/layernorm.py - vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py - vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py - vllm/model_executor/layers/quantization/input_quant_fp8.py
- vllm/model_executor/layers/fused_moe/layer.py
- tests/compile/test_fusion_attn.py - tests/compile/test_fusion_attn.py
- tests/compile/test_silu_mul_quant_fusion.py - tests/compile/test_silu_mul_quant_fusion.py
- tests/compile/distributed/test_fusion_all_reduce.py - tests/compile/distributed/test_fusion_all_reduce.py
@ -1132,12 +1163,25 @@ steps:
- vllm/model_executor/layers/activation.py - vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py - vllm/model_executor/layers/quantization/input_quant_fp8.py
- tests/compile/distributed/test_fusions_e2e.py - tests/compile/distributed/test_fusions_e2e.py
- tests/compile/fullgraph/test_full_graph.py
commands: commands:
- nvidia-smi - nvidia-smi
# Run all e2e fusion tests # Run all e2e fusion tests
- pytest -v -s tests/compile/distributed/test_fusions_e2e.py - pytest -v -s tests/compile/distributed/test_fusions_e2e.py
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
gpu: b200
optional: true # run on nightlies
source_file_dependencies:
- tests/evals/gpt_oss
- vllm/model_executor/models/gpt_oss.py
- vllm/model_executor/layers/quantization/mxfp4.py
- vllm/v1/attention/backends/flashinfer.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: Blackwell Quantized MoE Test - label: Blackwell Quantized MoE Test
timeout_in_minutes: 60 timeout_in_minutes: 60
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
@ -1155,6 +1199,16 @@ steps:
commands: commands:
- pytest -s -v tests/quantization/test_blackwell_moe.py - pytest -s -v tests/quantization/test_blackwell_moe.py
- label: Blackwell LM Eval Small Models
timeout_in_minutes: 120
gpu: b200
optional: true # run on nightlies
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
##### 1 GPU test ##### ##### 1 GPU test #####
##### multi gpus test ##### ##### multi gpus test #####
@ -1397,6 +1451,39 @@ steps:
- TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s -x lora/test_mixtral.py - pytest -v -s -x lora/test_mixtral.py
- label: LM Eval Large Models # optional
gpu: a100
optional: true
mirror_hardwares: [amdexperimental]
agent_pool: mi325_4
# grade: Blocking
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
##### H100 test #####
- label: LM Eval Large Models (H100) # optional
gpu: h100
optional: true
mirror_hardwares: [amdexperimental]
agent_pool: mi325_4
# grade: Blocking
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
##### H200 test ##### ##### H200 test #####
- label: Distributed Tests (H200) # optional - label: Distributed Tests (H200) # optional
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
@ -1440,29 +1527,6 @@ steps:
commands: commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- label: Blackwell LM Eval Small Models
timeout_in_minutes: 120
gpu: b200
optional: true # run on nightlies
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
- label: Multi-Modal Accuracy Eval (Small Models) # 10min
timeout_in_minutes: 70
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- vllm/multimodal/
- vllm/inputs/
- vllm/v1/core/
commands:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
- label: LM Eval Large Models (4 Card) - label: LM Eval Large Models (4 Card)
mirror_hardwares: [amdexperimental, amdproduction] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4 agent_pool: mi325_4
@ -1478,21 +1542,6 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
- label: LM Eval Large Models (H100) # optional
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4
# grade: Blocking
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
- label: ROCm LM Eval Large Models (8 Card) - label: ROCm LM Eval Large Models (8 Card)
mirror_hardwares: [amdproduction] mirror_hardwares: [amdproduction]
agent_pool: mi325_8 agent_pool: mi325_8
@ -1517,6 +1566,20 @@ steps:
- uv pip install --system 'gpt-oss[eval]==0.0.5' - uv pip install --system 'gpt-oss[eval]==0.0.5'
- VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
##### RL Integration Tests #####
- label: Prime-RL Integration Test # 15min
mirror_hardwares: [amdexperimental]
agent_pool: mi325_2
# grade: Blocking
timeout_in_minutes: 30
optional: true
num_gpus: 2
working_dir: "/vllm-workspace"
source_file_dependencies:
- vllm/
- .buildkite/scripts/run-prime-rl-test.sh
commands:
- bash .buildkite/scripts/run-prime-rl-test.sh
- label: DeepSeek V2-Lite Accuracy - label: DeepSeek V2-Lite Accuracy
mirror_hardwares: [amdexperimental, amdproduction] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4 agent_pool: mi325_4
@ -1550,17 +1613,26 @@ steps:
commands: commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
##### RL Integration Tests ##### - label: DeepSeek V2-Lite Async EPLB Accuracy
- label: Prime-RL Integration Test # 15min timeout_in_minutes: 60
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
agent_pool: mi325_2 agent_pool: mi325_4
# grade: Blocking # grade: Blocking
timeout_in_minutes: 30 gpu: h100
optional: true optional: true
num_gpus: 2 num_gpus: 4
working_dir: "/vllm-workspace" working_dir: "/vllm-workspace"
source_file_dependencies:
- vllm/
- .buildkite/scripts/run-prime-rl-test.sh
commands: commands:
- bash .buildkite/scripts/run-prime-rl-test.sh - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
timeout_in_minutes: 60
mirror_hardwares: [amdexperimental]
agent_pool: mi325_4
# grade: Blocking
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040

View File

@ -468,7 +468,9 @@ steps:
# tests covered elsewhere. # tests covered elsewhere.
# Use `find` to launch multiple instances of pytest so that # Use `find` to launch multiple instances of pytest so that
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965 # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;" # However, find does not normally propagate error codes, so we combine it with xargs
# (using -0 for proper path handling)
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
- label: PyTorch Fullgraph Smoke Test # 15min - label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30 timeout_in_minutes: 30
@ -482,7 +484,9 @@ steps:
# as it is a heavy test that is covered in other steps. # as it is a heavy test that is covered in other steps.
# Use `find` to launch multiple instances of pytest so that # Use `find` to launch multiple instances of pytest so that
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965 # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
- "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\\\;" # However, find does not normally propagate error codes, so we combine it with xargs
# (using -0 for proper path handling)
- "find compile/fullgraph -maxdepth 1 -name 'test_*.py' -not -name 'test_full_graph.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
- label: PyTorch Fullgraph Test # 27min - label: PyTorch Fullgraph Test # 27min
timeout_in_minutes: 40 timeout_in_minutes: 40

View File

@ -13,7 +13,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up Python - name: Set up Python
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0

View File

@ -12,7 +12,7 @@ jobs:
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v6.0.1
- uses: astral-sh/setup-uv@v7 - uses: astral-sh/setup-uv@v7
with: with:

View File

@ -16,7 +16,7 @@ jobs:
pre-commit: pre-commit:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
with: with:
python-version: "3.12" python-version: "3.12"

View File

@ -15,7 +15,7 @@ jobs:
actions: write actions: write
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with: with:
# Increasing this value ensures that changes to this workflow # Increasing this value ensures that changes to this workflow
# propagate to all issues and PRs in days rather than months # propagate to all issues and PRs in days rather than months

View File

@ -96,8 +96,9 @@ start_server() {
# This correctly passes each element as a separate argument. # This correctly passes each element as a separate argument.
if [[ -n "$profile_dir" ]]; then if [[ -n "$profile_dir" ]]; then
# Start server with profiling enabled # Start server with profiling enabled
VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}"
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & VLLM_SERVER_DEV_MODE=1 \
vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 &
else else
# Start server without profiling # Start server without profiling
VLLM_SERVER_DEV_MODE=1 \ VLLM_SERVER_DEV_MODE=1 \

View File

@ -963,8 +963,7 @@ def create_argument_parser():
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
help="Use Torch Profiler. The endpoint must be launched with " help="Use vLLM Profiling. --profiler-config must be provided on the server.",
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
) )
parser.add_argument( parser.add_argument(
"--result-dir", "--result-dir",

View File

@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
endif() endif()
# Build ACL with CMake # Build ACL with CMake
set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
set(CMAKE_BUILD_TYPE "Release")
set(ARM_COMPUTE_ARCH "armv8.2-a")
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
set(ARM_COMPUTE_BUILD_TESTING "OFF")
set(_cmake_config_cmd set(_cmake_config_cmd
${CMAKE_COMMAND} -G Ninja -B build ${CMAKE_COMMAND} -G Ninja -B build
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF -DARM_COMPUTE_BUILD_SHARED_LIB=OFF

View File

@ -186,7 +186,7 @@ struct AttentionMetadata {
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2 // - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
// * q_tile_size * 4, partial output, max + sum (float) // * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains: // Reduction scratchpad contains:
// - flags: bool array to indicate wether the split is finished // - flags: bool array to indicate whether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size // - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4 // - max, sum: 2 * split_num * q_tile_size * 4
class AttentionScratchPad { class AttentionScratchPad {

View File

@ -617,7 +617,7 @@ struct MacheteCollectiveMma {
// Same as upstream, should be kept the same when possible, not formatted for // Same as upstream, should be kept the same when possible, not formatted for
// easier comparison // easier comparison
// with `SwapAB ? N : M -> M` since we dont support SwapAB // with `SwapAB ? N : M -> M` since we don't support SwapAB
// clang-format off // clang-format off
template<class ProblemShape> template<class ProblemShape>
static bool static bool

View File

@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
} }
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support #endif // defined(__HIP__GFX9__) TODO: Add NAVI support
// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) { int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2; int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd; int rnds[13];
nPrRnd -= div1 * 3; for (int i = 0; i < 13; i++) {
int rnds3 = N / nPrRnd; rnds[i] = (N + nPrRnd - 1) / nPrRnd;
nPrRnd -= div1; nPrRnd -= div1;
int rnds4 = N / nPrRnd; }
nPrRnd -= div1; for (int i = 12; i >= 0; i--)
int rnds5 = N / nPrRnd; if (rnds[0] == rnds[i]) return (div2 - i);
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
} }
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2; const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ #define WVSPLITK(_YTILE, _UNRL, _N) \
_N) \ { \
{ \ dim3 block(64, 16); \
dim3 block(64, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ biasf4, c, __wvPrGrp, CuCount); \
biasf4, c, __wvPrGrp, CuCount); \ else if (K_in * N_in <= max_lds_len * 1.2) \
} else if (K_in * N_in <= max_lds_len * 1.2) { \ wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ biasf4, c, __wvPrGrp, CuCount); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ else \
biasf4, c, __wvPrGrp, CuCount); \ wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
} else { \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ biasf4, c, __wvPrGrp, CuCount); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ }
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \ #define WVSPLIT_TILE(_sYT, __N) \
} \ { \
bool fit_lds = (K_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
} }
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
? reinterpret_cast<const fptype*>(in_bias->data_ptr()) ? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr; : nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr()); fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
// first shoot for biggest tile-size that keeps all simd busy,
// then cut the active waves to balance their distribution...
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
switch (N_in) { switch (N_in) {
case 1: case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) WVSPLIT_TILE(sYT, 1)
break; break;
case 2: case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) WVSPLIT_TILE(sYT, 2)
break; break;
case 3: case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) WVSPLIT_TILE(sYT, 3)
break; break;
case 4: case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) WVSPLIT_TILE(sYT, 4)
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(

View File

@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes.
- [vllm.config.MultiModalConfig][] - [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][] - [vllm.config.PoolerConfig][]
- [vllm.config.StructuredOutputsConfig][] - [vllm.config.StructuredOutputsConfig][]
- [vllm.config.ProfilerConfig][]
- [vllm.config.ObservabilityConfig][] - [vllm.config.ObservabilityConfig][]
- [vllm.config.KVTransferConfig][] - [vllm.config.KVTransferConfig][]
- [vllm.config.CompilationConfig][] - [vllm.config.CompilationConfig][]

View File

@ -5,16 +5,15 @@
## Profile with PyTorch Profiler ## Profile with PyTorch Profiler
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables: We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by setting `--profiler-config`
when launching the server, and setting the entries `profiler` to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. Additionally, you can control the profiling content by specifying the following additional arguments in the config:
- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default - `torch_profiler_record_shapes` to enable recording Tensor Shapes, off by default
- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default - `torch_profiler_with_memory` to record memory, off by default
- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default - `torch_profiler_with_stack` to enable recording stack information, on by default
- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default - `torch_profiler_with_flops` to enable recording FLOPs, off by default
- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default - `torch_profiler_use_gzip` to control gzip-compressing profiling files, on by default
- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default - `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag. When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
#### OpenAI Server #### OpenAI Server
```bash ```bash
VLLM_TORCH_PROFILER_DIR=./vllm_profile \ vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
vllm serve meta-llama/Llama-3.1-8B-Instruct
``` ```
vllm bench command: vllm bench command:
@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with `
```bash ```bash
# server # server
VLLM_TORCH_CUDA_PROFILE=1 \
nsys profile \ nsys profile \
--trace-fork-before-exec=true \ --trace-fork-before-exec=true \
--cuda-graph-trace=node \ --cuda-graph-trace=node \
--capture-range=cudaProfilerApi \ --capture-range=cudaProfilerApi \
--capture-range-end repeat \ --capture-range-end repeat \
vllm serve meta-llama/Llama-3.1-8B-Instruct vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda
# client # client
vllm bench serve \ vllm bench serve \

View File

@ -22,7 +22,7 @@ python tools/install_nixl_from_source_ubuntu.py
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables: NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
```bash ```bash
# Example UCX configuration, adjust according to your enviroment # Example UCX configuration, adjust according to your environment
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1" export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
``` ```

View File

@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids return self.end_token_id in input_ids
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
return self.end_token_id in delta_ids
... ...
``` ```

View File

@ -1,14 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time import time
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# enable torch profiler, can also be set on cmd line
os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
@ -22,7 +18,14 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main(): def main():
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=1,
profiler_config={
"profiler": "torch",
"torch_profiler_dir": "./vllm_profile",
},
)
llm.start_profile() llm.start_profile()

View File

@ -75,7 +75,7 @@ torchgeo==0.7.0
mteb==2.1.2 mteb==2.1.2
# Data processing # Data processing
xgrammar==0.1.27 xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84
# Test async scheduling # Test async scheduling
# Utilities # Utilities

View File

@ -17,7 +17,6 @@ def test_compile():
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
@pytest.mark.xfail
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch): def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
"""Test that Qwen2.5-VL vision submodules are compiled. """Test that Qwen2.5-VL vision submodules are compiled.

View File

@ -80,6 +80,8 @@ def test_compile_ranges(use_fresh_inductor_cache):
vllm_config = VllmConfig( vllm_config = VllmConfig(
scheduler_config=SchedulerConfig( scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192, max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
), ),
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
@ -112,6 +114,8 @@ def test_compile_config_get_compile_ranges():
VllmConfig( VllmConfig(
scheduler_config=SchedulerConfig( scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192, max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
), ),
compilation_config=compilation_config, compilation_config=compilation_config,
) )
@ -134,6 +138,8 @@ def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
) )
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
max_num_batched_tokens=8192, max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
) )
torch.set_default_device("cuda") torch.set_default_device("cuda")

View File

@ -1,10 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest import pytest
import torch import torch
import vllm.config import vllm.config
import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.matcher_utils import QUANT_OPS
@ -237,13 +241,85 @@ def _generate_kernel_groupshape_combinations():
KERNEL_GROUPSHAPE_COMBINATIONS = _generate_kernel_groupshape_combinations() KERNEL_GROUPSHAPE_COMBINATIONS = _generate_kernel_groupshape_combinations()
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, **kwargs):
super().__init__()
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = [
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
for _ in range(3)
]
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
self.eps = eps
self.w8a8_block_fp8_linear = [
TestBlockFP8Layer(
GroupShape(128, 128),
self.w[i],
self.wscale[i],
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
for i in range(3)
]
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x2 = self.w8a8_block_fp8_linear[0](y)
# make sure resid is used for replacement to work
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
x3 = self.w8a8_block_fp8_linear[1](y2)
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
x3, resid, self.norm_weight[2], self.eps
)
x4 = self.w8a8_block_fp8_linear[2](y3)
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
x4, resid, self.norm_weight[3], self.eps
)
return y4
def ops_in_model_before(self):
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS) @pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize(
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
) )
@ -253,9 +329,13 @@ def test_fusion_rmsnorm_quant(
num_tokens, num_tokens,
eps, eps,
kernel_groupshape, kernel_groupshape,
model_class,
enable_rms_norm_custom_op, enable_rms_norm_custom_op,
enable_quant_fp8_custom_op, enable_quant_fp8_custom_op,
): ):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
@ -290,7 +370,14 @@ def test_fusion_rmsnorm_quant(
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config) if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
@ -325,7 +412,10 @@ def test_fusion_rmsnorm_quant(
# there's a risk that the fused add doesn't get included in the # there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant. # replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add). # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op: if (
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7 assert n_add_nodes(backend.graph_pre_pass) == 7

View File

@ -5,9 +5,14 @@ import copy
import pytest import pytest
import torch import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.inductor_pass import (
CallableInductorPass,
InductorPass,
pass_context,
)
from vllm.compilation.pass_manager import PostGradPassManager from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.config.utils import Range
# dummy custom pass that doesn't inherit # dummy custom pass that doesn't inherit
@ -42,35 +47,37 @@ class ProperPass(InductorPass):
], ],
) )
def test_pass_manager_uuid(callable): def test_pass_manager_uuid(callable):
# Some passes need dtype to be set # Set the pass context as PassManager uuid uses it
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) with pass_context(Range(start=1, end=8)):
# Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
pass_manager = PostGradPassManager() pass_manager = PostGradPassManager()
pass_manager.configure(config) pass_manager.configure(config)
# Check that UUID is different if the same pass is added 2x # Check that UUID is different if the same pass is added 2x
pass_manager.add(callable) pass_manager.add(callable)
uuid1 = pass_manager.uuid() uuid1 = pass_manager.uuid()
pass_manager.add(callable) pass_manager.add(callable)
uuid2 = pass_manager.uuid() uuid2 = pass_manager.uuid()
assert uuid1 != uuid2 assert uuid1 != uuid2
# UUID should be the same as the original one, # UUID should be the same as the original one,
# as we constructed in the same way. # as we constructed in the same way.
pass_manager2 = PostGradPassManager() pass_manager2 = PostGradPassManager()
pass_manager2.configure(config) pass_manager2.configure(config)
pass_manager2.add(callable) pass_manager2.add(callable)
assert uuid1 == pass_manager2.uuid() assert uuid1 == pass_manager2.uuid()
# UUID should be different due to config change # UUID should be different due to config change
config2 = copy.deepcopy(config) config2 = copy.deepcopy(config)
config2.compilation_config.pass_config.fuse_norm_quant = ( config2.compilation_config.pass_config.fuse_norm_quant = (
not config2.compilation_config.pass_config.fuse_norm_quant not config2.compilation_config.pass_config.fuse_norm_quant
) )
config2.compilation_config.pass_config.fuse_act_quant = ( config2.compilation_config.pass_config.fuse_act_quant = (
not config2.compilation_config.pass_config.fuse_act_quant not config2.compilation_config.pass_config.fuse_act_quant
) )
pass_manager3 = PostGradPassManager() pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2) pass_manager3.configure(config2)
pass_manager3.add(callable) pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid() assert uuid1 != pass_manager3.uuid()

View File

@ -7,6 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._aiter_ops import IS_AITER_FOUND
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import ( from vllm.compilation.activation_quant_fusion import (
FUSED_OPS, FUSED_OPS,
@ -24,6 +25,7 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return [FUSED_OPS[kNvfp4Quant]] return [FUSED_OPS[kNvfp4Quant]]
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = torch.rand(
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
]
def ops_in_model_after(self):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
@pytest.mark.parametrize("num_tokens", [32, 64]) @pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch", "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])) list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)], + [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
],
) )
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], model_class: type[
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
],
enable_silu_mul_custom_op: bool, enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool, enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool, cuda_force_torch: bool,
): ):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.") pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
@ -172,9 +216,15 @@ def test_fusion_silu_and_mul_quant(
) )
with set_current_vllm_config(config): with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config) fusion_passes = [ActivationQuantFusionPass(config)]
if IS_AITER_FOUND:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes) backend = TestBackend(*passes)
model = model_class( model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
@ -193,12 +243,14 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-3, 1e-3 atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel: elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1 atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close( torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
) )
assert fusion_pass.matched_count == 1 assert sum([p.matched_count for p in fusion_passes]) == 1
# In pre-nodes, quant op should be present and fused kernels should not # In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before()) backend.check_before_ops(model.ops_in_model_before())

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_tool_call_output_item import ( from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem, ResponseFunctionToolCallOutputItem,
) )
@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
) )
from vllm.entrypoints.responses_utils import ( from vllm.entrypoints.responses_utils import (
construct_chat_message_with_tool_call, _construct_single_message_from_response_item,
construct_chat_messages_with_tool_call,
convert_tool_responses_to_completions_format, convert_tool_responses_to_completions_format,
) )
@ -42,7 +44,43 @@ class TestResponsesUtils:
assert result == {"type": "function", "function": input_tool} assert result == {"type": "function", "function": input_tool}
def test_construct_chat_message_with_tool_call(self): def test_construct_chat_messages_with_tool_call(self):
"""Test construction of chat messages with tool calls."""
reasoning_item = ResponseReasoningItem(
id="lol",
summary=[],
type="reasoning",
content=[
Content(
text="Leroy Jenkins",
type="reasoning_text",
)
],
encrypted_content=None,
status=None,
)
mcp_tool_item = ResponseFunctionToolCall(
id="mcp_123",
call_id="call_123",
type="function_call",
status="completed",
name="python",
arguments='{"code": "123+456"}',
)
input_items = [reasoning_item, mcp_tool_item]
messages = construct_chat_messages_with_tool_call(input_items)
assert len(messages) == 1
message = messages[0]
assert message["role"] == "assistant"
assert message["reasoning"] == "Leroy Jenkins"
assert message["tool_calls"][0]["id"] == "call_123"
assert message["tool_calls"][0]["function"]["name"] == "python"
assert (
message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
)
def test_construct_single_message_from_response_item(self):
item = ResponseReasoningItem( item = ResponseReasoningItem(
id="lol", id="lol",
summary=[], summary=[],
@ -56,7 +94,7 @@ class TestResponsesUtils:
encrypted_content=None, encrypted_content=None,
status=None, status=None,
) )
formatted_item = construct_chat_message_with_tool_call(item) formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant" assert formatted_item["role"] == "assistant"
assert formatted_item["reasoning"] == "Leroy Jenkins" assert formatted_item["reasoning"] == "Leroy Jenkins"
@ -74,7 +112,7 @@ class TestResponsesUtils:
status=None, status=None,
) )
formatted_item = construct_chat_message_with_tool_call(item) formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant" assert formatted_item["role"] == "assistant"
assert ( assert (
formatted_item["reasoning"] formatted_item["reasoning"]
@ -88,7 +126,7 @@ class TestResponsesUtils:
output="1234", output="1234",
status="completed", status="completed",
) )
formatted_item = construct_chat_message_with_tool_call(tool_call_output) formatted_item = _construct_single_message_from_response_item(tool_call_output)
assert formatted_item["role"] == "tool" assert formatted_item["role"] == "tool"
assert formatted_item["content"] == "1234" assert formatted_item["content"] == "1234"
assert formatted_item["tool_call_id"] == "temp" assert formatted_item["tool_call_id"] == "temp"
@ -102,7 +140,7 @@ class TestResponsesUtils:
status=None, status=None,
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
construct_chat_message_with_tool_call(item) _construct_single_message_from_response_item(item)
output_item = ResponseOutputMessage( output_item = ResponseOutputMessage(
id="msg_bf585bbbe3d500e0", id="msg_bf585bbbe3d500e0",
@ -119,6 +157,6 @@ class TestResponsesUtils:
type="message", type="message",
) )
formatted_item = construct_chat_message_with_tool_call(output_item) formatted_item = _construct_single_message_from_response_item(output_item)
assert formatted_item["role"] == "assistant" assert formatted_item["role"] == "assistant"
assert formatted_item["content"] == "dongyi" assert formatted_item["content"] == "dongyi"

View File

@ -7,7 +7,8 @@ import math
import pytest import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
if not current_platform.is_cpu(): if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True) pytest.skip("skipping CPU-only tests", allow_module_level=True)
@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len)
] ]
def get_attn_isa(
block_size: int | None = None,
dtype: torch.dtype | None = None,
):
if block_size and dtype:
return _get_attn_isa(dtype, block_size)
else:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
return "neon"
elif torch._C._cpu._is_amx_tile_supported():
return "amx"
else:
return "vec"
# rand number generation takes too much time, cache rand tensors # rand number generation takes too much time, cache rand tensors
@functools.lru_cache(maxsize=128, typed=False) @functools.lru_cache(maxsize=128, typed=False)
def tensor_cache( def tensor_cache(
@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
) )
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", QTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["neon"])
@pytest.mark.skipif(
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
reason="Not an Arm CPU.",
)
def test_varlen_with_paged_kv_normal_neon(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS) @pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96]) @pytest.mark.parametrize("head_size", [96])
@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False]) @pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False]) @pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize( @pytest.mark.parametrize("isa", [get_attn_isa()])
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_softcap( def test_varlen_with_paged_kv_softcap(
seq_lens: list[tuple[int, int]], seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int], num_heads: tuple[int, int],
@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [True]) @pytest.mark.parametrize("use_alibi", [True])
@pytest.mark.parametrize("use_sink", [False]) @pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize( @pytest.mark.parametrize("isa", [get_attn_isa()])
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_alibi( def test_varlen_with_paged_kv_alibi(
seq_lens: list[tuple[int, int]], seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int], num_heads: tuple[int, int],
@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False]) @pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [True]) @pytest.mark.parametrize("use_sink", [True])
@pytest.mark.parametrize( @pytest.mark.parametrize("isa", [get_attn_isa()])
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_sink( def test_varlen_with_paged_kv_sink(
seq_lens: list[tuple[int, int]], seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int], num_heads: tuple[int, int],

View File

@ -26,7 +26,14 @@ def clear_cache():
_cached_get_attn_backend.cache_clear() _cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) devices = ["cpu"]
if current_platform.is_cuda():
devices.append("cuda")
if current_platform.is_rocm():
devices.append("hip")
@pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(device: str): def test_mha_attn_platform(device: str):
""" """
Test the attention selector across different platforms and devices. Test the attention selector across different platforms and devices.
@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
else: else:
# Test CUDA with head_size=64 (divisible by 32) # Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention

View File

@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
.clamp(fp8_traits_min, fp8_traits_max) .clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE) .to(FP8_DTYPE)
) )
return ref_out, ref_scale.view((1, 1)) return ref_out, ref_scale.view(1)
def native_w8a8_block_matmul( def native_w8a8_block_matmul(

View File

@ -54,6 +54,10 @@ def setup_cuda():
torch.set_default_device("cuda") torch.set_default_device("cuda")
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed", "num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS), itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
factor_for_scale = 1e-2 factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_info = torch.finfo(current_platform.fp8_dtype())
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
block_n, block_k = block_size[0], block_size[1] block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n n_tiles = (N + block_n - 1) // block_n
@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001 assert rel_diff < 0.001
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
)
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_cutlass_matmul(): def test_w8a8_block_fp8_cutlass_matmul():
# Test simple case where weight.shape % 128 != 0, # Test simple case where weight.shape % 128 != 0,
@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul():
assert rel_diff < 0.001 assert rel_diff < 0.001
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed", "M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),

View File

@ -15,6 +15,9 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
MNK_FACTORS = [ MNK_FACTORS = [
(1, 256, 128), (1, 256, 128),
(1, 16384, 1024), (1, 16384, 1024),

View File

@ -21,6 +21,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of # unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods # `is_quant_method_supported` conflates kernels with quantization methods

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64 import base64
import mimetypes import mimetypes
import os import os
@ -186,6 +187,7 @@ async def test_fetch_image_error_conversion():
connector.fetch_image(broken_img) connector.fetch_image(broken_img)
@pytest.mark.flaky(reruns=3, reruns_delay=5)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800]) @pytest.mark.parametrize("num_frames", [-1, 32, 1800])
@ -198,8 +200,12 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
} }
) )
video_sync, metadata_sync = connector.fetch_video(video_url) try:
video_async, metadata_async = await connector.fetch_video_async(video_url) video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
except (TimeoutError, asyncio.TimeoutError) as e:
pytest.skip(f"Timeout fetching video (CI network flakiness): {e}")
assert np.array_equal(video_sync, video_async) assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async assert metadata_sync == metadata_async

View File

@ -10,10 +10,14 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod, Fp8LinearMethod,
Fp8MoEMethod,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
MODELS = [ MODELS = [
@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
), ),
) )
# Checks that an FP8-quantized layer accepts a second round of weight loading
# after process_weights_after_loading() has already run once.
# Flow: build a layer and its FP8 method, snapshot every parameter's
# (name, shape, weight_loader), load zeros, post-process, then load zeros and
# post-process a second time.
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# any postprocessing that is applied to the weights such as padding and repacking
# (excluding device sharding) must also be applied to the reloaded weights
#
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading(
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
):
# NOTE(review): dist_init is requested as a fixture but never referenced in
# the body; presumably it sets up distributed state needed by FusedMoE -
# confirm against conftest.
# Guard for the unsupported case; only reachable if False is added back to
# the is_checkpoint_fp8_serialized parametrization above.
if is_checkpoint_fp8_serialized is False:
pytest.skip("FP8 weight reloading does not support online quantization")
# weight_block_size=None combined with the MoE method is skipped as unsupported.
if method_cls is Fp8MoEMethod and weight_block_size is None:
pytest.skip(
"FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
"If this is your use case, consider using a restore function like #26327"
)
with torch.device("cuda:0"):
config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
weight_block_size=weight_block_size,
)
# The two method classes take different constructor and create_weights
# arguments, so each gets its own minimal 1x1 layer setup.
if method_cls is Fp8LinearMethod:
layer = torch.nn.Linear(1, 1)
method = method_cls(config)
method.create_weights(
layer=layer,
input_size_per_partition=1,
output_partition_sizes=[1],
input_size=1,
output_size=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
else:
layer = FusedMoE(
num_experts=1,
top_k=1,
hidden_size=1,
intermediate_size=1,
)
method = method_cls(config, layer)
method.create_weights(
layer=layer,
num_experts=1,
hidden_size=1,
intermediate_size_per_partition=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
# Set directly on the method object so the parametrized value is what the
# post-processing step sees.
method.use_marlin = use_marlin
# capture weights format during loading
# (snapshot is taken before any weight is loaded or post-processed;
# parameters without a weight_loader attribute fall back to default_weight_loader)
original_metadata = [
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
for name, param in layer.named_parameters()
]
# test loading
for name, shape, _ in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)
# test reloading works after loading
# assuming that no reshaping occurred
for name, shape, original_weight_loader in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
# Identity check: post-processing must not have swapped out the loader
# that was attached to this parameter before the first load.
assert weight_loader is original_weight_loader
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)

View File

@ -132,6 +132,41 @@ class TestBaseThinkingReasoningParserMethods:
is False is False
) )
def test_is_reasoning_end_streaming(self, test_tokenizer):
    """Exercise is_reasoning_end_streaming over a table of token sequences.

    Each case is (input_ids, delta_ids, expected). The method is expected to
    return True only for the cases whose delta carries the end token, and
    False for an empty input, for deltas without the end token, and for
    deltas that arrive after the end token was already seen.
    """
    parser = TestThinkingReasoningParser(test_tokenizer)
    end_id = parser.end_token_id
    start_id = parser.start_token_id

    cases = [
        # end token arrives in the current delta, no start token present
        ([1, 2, end_id], [end_id], True),
        # no reasoning tokens at all
        ([1, 2, 3, 4], [4], False),
        # empty input and empty delta
        ([], [], False),
        # start ... end, with end in the current delta
        ([1, start_id, 2, end_id], [end_id], True),
        # reasoning started but not finished
        ([1, start_id, 2, 3], [3], False),
        # reasoning ended earlier, then a new start token appeared
        ([1, start_id, 2, end_id, 2, start_id, 2], [2], False),
        # reasoning ended earlier; current delta is ordinary content
        ([1, start_id, 2, end_id, 2, 2], [2], False),
    ]

    for input_ids, delta_ids, expected in cases:
        assert parser.is_reasoning_end_streaming(input_ids, delta_ids) is expected
def test_extract_content_ids(self, test_tokenizer): def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method.""" """Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer) parser = TestThinkingReasoningParser(test_tokenizer)

View File

@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
input_tokens = tokenizer.tokenize(input_text) input_tokens = tokenizer.tokenize(input_text)
input_ids = tokenizer.convert_tokens_to_ids(input_tokens) input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
assert parser.is_reasoning_end(input_ids) is True assert parser.is_reasoning_end(input_ids) is True
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
# Test extract_content_ids returns all input_ids # Test extract_content_ids returns all input_ids
assert parser.extract_content_ids(input_ids) == input_ids assert parser.extract_content_ids(input_ids) == input_ids

View File

@ -615,6 +615,7 @@ def test_extract_tool_calls_streaming(
"single_tool_weather", "single_tool_weather",
"multiple_tool_calls", "multiple_tool_calls",
"content_before_tool", "content_before_tool",
"complex",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
@ -673,6 +674,21 @@ def test_extract_tool_calls_streaming(
], ],
"bla", "bla",
), ),
(
# Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"",
),
], ],
) )
def test_extract_tool_calls_streaming_one_chunk( def test_extract_tool_calls_streaming_one_chunk(

View File

@ -53,9 +53,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
QuantKey, QuantKey,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
@ -1381,8 +1378,8 @@ class TestBlockFP8Layer:
""" """
Test wrapper for W8A8BlockFp8LinearOp to match TestFP8Layer interface. Test wrapper for W8A8BlockFp8LinearOp to match TestFP8Layer interface.
This is a workaround until W8A8BlockFp8LinearOp has a similar API to This is a workaround until W8A8BlockFp8LinearOp implements
FP8ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization). ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization).
""" """
def __init__( def __init__(
@ -1390,20 +1387,21 @@ class TestBlockFP8Layer:
group_shape: GroupShape, group_shape: GroupShape,
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: torch.Tensor, input_scale: torch.Tensor | None = None,
cutlass_block_fp8_supported: bool = False,
use_aiter_and_is_supported: bool = False,
): ):
self.kernel = None # For compatibility with TestFP8Layer interface
self.linear_op = W8A8BlockFp8LinearOp( self.linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]), weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape, act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(), cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=False, use_aiter_and_is_supported=use_aiter_and_is_supported,
) )
self.weight = weight self.weight = weight
self.weight_scale = weight_scale self.weight_scale = weight_scale
self.input_scale = input_scale self.input_scale = input_scale
def forward( def __call__(
self, y: torch.Tensor, bias: torch.Tensor | None = None self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
return self.linear_op.apply( return self.linear_op.apply(

View File

@ -106,8 +106,8 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu, _seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size, num_reqs=batch_spec.batch_size,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,

View File

@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15) assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fall back mode # 4. disable_full should have a fall back mode (e.g., cascade attention)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False) desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch( rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
) )
if "PIECEWISE" in cudagraph_mode_str: # string contains check if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE assert rt_mode == CUDAGraphMode.PIECEWISE

View File

@ -10,6 +10,7 @@ from utils import (
BACKENDS, BACKENDS,
_extract_step_logprobs, _extract_step_logprobs,
_random_prompt, _random_prompt,
is_device_capability_below_90,
resolve_model_name, resolve_model_name,
skip_unsupported, skip_unsupported,
) )
@ -17,6 +18,8 @@ from utils import (
import vllm.model_executor.layers.batch_invariant as batch_invariant import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
@skip_unsupported @skip_unsupported
@pytest.mark.timeout(1000) @pytest.mark.timeout(1000)
@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", # not everything is supported dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
# Use more realistic prompts for better token generation # Use more realistic prompts for better token generation
@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
max_model_len=2048, max_model_len=2048,
dtype="bfloat16", dtype="bfloat16",
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
prompt = "the capital of france is" prompt = "the capital of france is"
@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
# build ragged prompts to change shapes significantly across BS=1 vs BS=N # build ragged prompts to change shapes significantly across BS=1 vs BS=N
@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
# Use a few test prompts # Use a few test prompts
@ -925,6 +933,8 @@ def LLM_with_max_seqs(
max_model_len=max_model_len, max_model_len=max_model_len,
dtype="bfloat16", dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
# Enable for MOE models # Enable for MOE models
# enable_expert_parallel=True, # enable_expert_parallel=True,
) )

View File

@ -11,8 +11,10 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
skip_unsupported = pytest.mark.skipif( skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)), not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
reason="Requires CUDA and >= Hopper (SM90)", # Supports testing on Ampere and Ada Lovelace devices.
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
reason="Requires CUDA and >= Ampere (SM80)",
) )
BACKENDS: list[str] = [ BACKENDS: list[str] = [
@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output):
return t, inner.token_ids return t, inner.token_ids
return None, None return None, None
def is_device_capability_below_90() -> bool:
    """Return True when the current platform does not report device capability 90.

    This is the logical negation of ``current_platform.has_device_capability(90)``.
    """
    meets_capability_90 = current_platform.has_device_capability(90)
    return not meets_capability_90

View File

@ -8,6 +8,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams from vllm import SamplingParams
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.platforms import current_platform
from vllm.sampling_params import StructuredOutputsParams from vllm.sampling_params import StructuredOutputsParams
from vllm.v1.metrics.reader import Metric from vllm.v1.metrics.reader import Metric
@ -70,6 +71,18 @@ def test_without_spec_decoding(
(True, "uni", True, None, True), (True, "uni", True, None, True),
] ]
if current_platform.is_rocm():
# On ROCm, only test with structured_outputs (deterministic)
# and skip chunk_prefill (more variable).
test_configs = [
cfg
for cfg in test_configs
if not cfg[4] # skip chunk_prefill=True
]
test_sampling_params = [
p for p in test_sampling_params if p.get("structured_outputs") is not None
]
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
@ -108,7 +121,14 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True), (True, "uni", True, spec_config_short, True),
] ]
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params) # On ROCm, use TRITON_ATTN + float32 for better numerical consistency
run_tests(
monkeypatch,
MTP_MODEL,
test_configs,
test_sampling_params,
is_testing_with_spec_decoding=True,
)
@dynamo_config.patch(cache_size_limit=16) @dynamo_config.patch(cache_size_limit=16)
@ -117,13 +137,23 @@ def run_tests(
model: str, model: str,
test_configs: list[tuple], test_configs: list[tuple],
test_sampling_params: list[dict[str, Any]], test_sampling_params: list[dict[str, Any]],
is_testing_with_spec_decoding: bool = False,
): ):
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m: with monkeypatch.context() as m:
# avoid precision errors # avoid precision errors
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") if current_platform.is_rocm():
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
# lock matmul precision to full FP32
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
# m.setenv("VLLM_BATCH_INVARIANT", "1") # m.setenv("VLLM_BATCH_INVARIANT", "1")
outputs: list[tuple[str, list, list]] = [] outputs: list[tuple[str, list, list]] = []
for n, ( for n, (
@ -143,6 +173,7 @@ def run_tests(
async_scheduling, async_scheduling,
spec_config, spec_config,
test_prefill_chunking=test_prefill_chunking, test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
) )
outputs.append(test_results) outputs.append(test_results)
@ -172,17 +203,34 @@ def run_tests(
name_0=f"baseline=[{baseline_config}], params={params}", name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}", name_1=f"config=[{test_config}], params={params}",
) )
assert _all_logprobs_match(base_logprobs, test_logprobs)
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
# logprobs comparison when logprobs are requested
skip_logprobs_check = (
current_platform.is_rocm()
and params.get("logprobs")
and is_testing_with_spec_decoding
)
if not skip_logprobs_check:
assert _all_logprobs_match(base_logprobs, test_logprobs)
if ( if (
base_acceptance_rate is not None base_acceptance_rate is not None
and test_acceptance_rate is not None and test_acceptance_rate is not None
): ):
if "spec_mml=None" in test_config: if "spec_mml=None" in test_config:
# Preemption causes more variance in acceptance rates
if (
current_platform.is_rocm()
and "preemption=True" in test_config
):
tolerance = 0.10
else:
tolerance = 0.05
assert ( assert (
test_acceptance_rate > base_acceptance_rate test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate or test_acceptance_rate
== pytest.approx(base_acceptance_rate, rel=5e-2) == pytest.approx(base_acceptance_rate, rel=tolerance)
) )
else: else:
# Currently the reported acceptance rate is expected to be # Currently the reported acceptance rate is expected to be
@ -213,6 +261,7 @@ def run_test(
async_scheduling: bool, async_scheduling: bool,
spec_config: dict[str, Any] | None, spec_config: dict[str, Any] | None,
test_prefill_chunking: bool, test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False,
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = ( cache_arg: dict[str, Any] = (
@ -231,6 +280,15 @@ def run_test(
print("-" * 80) print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}") print(f"---- TESTING {test_str}: {test_config}")
print("-" * 80) print("-" * 80)
# On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
# spec decoding test (TRITON_ATTN) for better precision.
# On others: always use float32.
if current_platform.is_rocm() and not is_testing_with_spec_decoding:
dtype = "float16"
else:
dtype = "float32"
with VllmRunner( with VllmRunner(
model, model,
max_model_len=512, max_model_len=512,
@ -240,7 +298,7 @@ def run_test(
# enforce_eager=True, # enforce_eager=True,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
distributed_executor_backend=executor, distributed_executor_backend=executor,
dtype="float32", # avoid precision errors dtype=dtype,
speculative_config=spec_config, speculative_config=spec_config,
disable_log_stats=False, disable_log_stats=False,
**cache_arg, **cache_arg,
@ -300,11 +358,21 @@ def _all_logprobs_match(req_a, req_b) -> bool:
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool: def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
return len(lps_a) == len(lps_b) and all( if current_platform.is_rocm():
a.decoded_token == b.decoded_token # ROCm has higher numerical variance
and a.rank == b.rank # due to use of float16.
and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6) rel_tol, abs_tol = 5e-2, 1e-5
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a) else:
rel_tol, abs_tol = 1e-3, 1e-6
return (
len(lps_a) == len(lps_b)
and lps_a.keys() == lps_b.keys()
and all(
a.decoded_token == b.decoded_token
and a.rank == b.rank
and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
)
) )

View File

@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""
import multiprocessing
import sys
import traceback
import pytest
import torch
@pytest.fixture
def sync_tracker():
    """
    Replace CommonAttentionMetadata.seq_lens_cpu with a counting wrapper.

    The wrapper treats an access made while ``_seq_lens_cpu`` is still None as
    a lazy-init sync: it bumps a shared counter and dumps the current stack to
    stderr before delegating to the original getter. The yielded object
    exposes the counter and an ``assert_no_sync`` helper. On teardown the
    original property is put back and torch._dynamo is reset.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # multiprocessing.Value rather than a plain int so that increments made in
    # worker processes are visible here (relies on workers inheriting it by fork)
    sync_count = multiprocessing.Value("i", 0)

    # Keep both the property object (for restoring) and its getter (for delegating)
    original_prop = CommonAttentionMetadata.seq_lens_cpu
    original_fget = original_prop.fget

    def tracking_seq_lens_cpu(self):
        # Already populated: nothing to report, just delegate
        if self._seq_lens_cpu is not None:
            return original_fget(self)

        with sync_count.get_lock():
            sync_count.value += 1
            count = sync_count.value

        # Emit the report right away so it shows up in subprocess output
        for line in (
            f"\n{'=' * 60}",
            f"SYNC #{count}: seq_lens_cpu lazy init triggered!",
            f"{'=' * 60}",
        ):
            print(line, file=sys.stderr)
        traceback.print_stack(file=sys.stderr)
        print(f"{'=' * 60}\n", file=sys.stderr)
        sys.stderr.flush()

        return original_fget(self)

    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        @property
        def count(self) -> int:
            return sync_count.value

        def assert_no_sync(self, msg: str = ""):
            count = sync_count.value
            assert count == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{count} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Teardown: undo the patch and reset dynamo state
    CommonAttentionMetadata.seq_lens_cpu = original_prop
    torch._dynamo.reset()
# Test configurations: (model, spec_model, method, num_spec_tokens)
# The field order must match the parametrize string used by
# test_no_sync_with_spec_decode; each entry's `id` is its pytest case label.
SPEC_DECODE_CONFIGS = [
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama3_2_1B_speculator.eagle3",
"eagle3",
2,
id="eagle3-llama",
),
pytest.param(
"eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random",
"eagle",
2,
id="eagle-mla-deepseek",
),
]
@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.

    Args:
        sync_tracker: fixture that counts lazy seq_lens_cpu initializations.
        model: target model name passed to LLM.
        spec_model: draft/speculator model name for the speculative config.
        method: speculative decoding method name.
        num_spec_tokens: number of speculative tokens per step.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config={
            "method": method,
            "num_speculative_tokens": num_spec_tokens,
            "model": spec_model,
        },
        enforce_eager=True,
        async_scheduling=True,
    )

    try:
        outputs = llm.generate(
            ["Hello, my name is"],
            SamplingParams(temperature=0, max_tokens=10),
        )

        assert len(outputs) == 1
        assert len(outputs[0].outputs[0].text) > 0
    finally:
        # Always release the engine, even when generation or an assertion
        # fails, so a failing case does not leave GPU/distributed state
        # behind for the next parametrized case.
        del llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Checked last, after teardown of the engine, as in the success path before
    sync_tracker.assert_no_sync()

View File

@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
# Expect the acceptance rate to improve. # Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 82.5% acceptance rate at the end. # Heuristic: expect at least 80.0% acceptance rate at the end.
assert last_accept_rate > 0.825 assert last_accept_rate > 0.80
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -88,8 +88,8 @@ def forward_attention(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(), query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(), _seq_lens_cpu=seq_lens.cpu(),
num_computed_tokens_cpu=context_lens.cpu(), _num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size, num_reqs=batch_size,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,

View File

@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
request.use_structured_output = True request.use_structured_output = True
request.prompt_token_ids = [1, 2, 3, 4, 5] request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8] request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
request.num_computed_tokens = 5
return request return request
def test_should_fill_bitmask_with_enable_in_reasoning( def test_should_fill_bitmask_with_enable_in_reasoning(

View File

@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import vllm.envs as envs from vllm.config import ProfilerConfig
from vllm.profiler.gpu_profiler import WorkerProfiler from vllm.profiler.wrapper import WorkerProfiler
class ConcreteWorkerProfiler(WorkerProfiler): class ConcreteWorkerProfiler(WorkerProfiler):
@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
A basic implementation of a worker profiler for testing purposes. A basic implementation of a worker profiler for testing purposes.
""" """
def __init__(self): def __init__(self, profiler_config: ProfilerConfig):
self.start_call_count = 0 self.start_call_count = 0
self.stop_call_count = 0 self.stop_call_count = 0
self.should_fail_start = False self.should_fail_start = False
super().__init__() super().__init__(profiler_config)
def _start(self) -> None: def _start(self) -> None:
if self.should_fail_start: if self.should_fail_start:
@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
self.stop_call_count += 1 self.stop_call_count += 1
@pytest.fixture(autouse=True) @pytest.fixture
def reset_mocks(): def default_profiler_config():
"""Fixture to reset mocks and env variables before each test.""" return ProfilerConfig(
envs.VLLM_PROFILER_DELAY_ITERS = 0 profiler="torch",
envs.VLLM_PROFILER_MAX_ITERS = 0 torch_profiler_dir="/tmp/mock",
delay_iterations=0,
max_iterations=0,
)
def test_immediate_start_stop(): def test_immediate_start_stop(default_profiler_config):
"""Test standard start without delay.""" """Test standard start without delay."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
assert profiler._running is True assert profiler._running is True
assert profiler._active is True assert profiler._active is True
@ -48,10 +50,10 @@ def test_immediate_start_stop():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_delayed_start(): def test_delayed_start(default_profiler_config):
"""Test that profiler waits for N steps before actually starting.""" """Test that profiler waits for N steps before actually starting."""
envs.VLLM_PROFILER_DELAY_ITERS = 2 default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# User requests start # User requests start
profiler.start() profiler.start()
@ -71,10 +73,10 @@ def test_delayed_start():
assert profiler.start_call_count == 1 assert profiler.start_call_count == 1
def test_max_iterations(): def test_max_iterations(default_profiler_config):
"""Test that profiler stops automatically after max iterations.""" """Test that profiler stops automatically after max iterations."""
envs.VLLM_PROFILER_MAX_ITERS = 2 default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
assert profiler._running is True assert profiler._running is True
@ -95,12 +97,11 @@ def test_max_iterations():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_delayed_start_and_max_iters(): def test_delayed_start_and_max_iters(default_profiler_config):
"""Test combined delayed start and max iterations.""" """Test combined delayed start and max iterations."""
envs.VLLM_PROFILER_DELAY_ITERS = 2 default_profiler_config.delay_iterations = 2
envs.VLLM_PROFILER_MAX_ITERS = 2 default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
# Step 1 # Step 1
@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_idempotency(): def test_idempotency(default_profiler_config):
"""Test that calling start/stop multiple times doesn't break logic.""" """Test that calling start/stop multiple times doesn't break logic."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# Double Start # Double Start
profiler.start() profiler.start()
@ -142,10 +143,10 @@ def test_idempotency():
assert profiler.stop_call_count == 1 # Should only stop once assert profiler.stop_call_count == 1 # Should only stop once
def test_step_inactive(): def test_step_inactive(default_profiler_config):
"""Test that stepping while inactive does nothing.""" """Test that stepping while inactive does nothing."""
envs.VLLM_PROFILER_DELAY_ITERS = 2 default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# Not started yet # Not started yet
profiler.step() profiler.step()
@ -155,9 +156,9 @@ def test_step_inactive():
assert profiler.start_call_count == 0 assert profiler.start_call_count == 0
def test_start_failure(): def test_start_failure(default_profiler_config):
"""Test behavior when the underlying _start method raises exception.""" """Test behavior when the underlying _start method raises exception."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.should_fail_start = True profiler.should_fail_start = True
profiler.start() profiler.start()
@ -168,9 +169,9 @@ def test_start_failure():
assert profiler.start_call_count == 0 # Logic failed inside start assert profiler.start_call_count == 0 # Logic failed inside start
def test_shutdown(): def test_shutdown(default_profiler_config):
"""Test that shutdown calls stop only if running.""" """Test that shutdown calls stop only if running."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# Case 1: Not running # Case 1: Not running
profiler.shutdown() profiler.shutdown()
@ -182,10 +183,10 @@ def test_shutdown():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_mixed_delay_and_stop(): def test_mixed_delay_and_stop(default_profiler_config):
"""Test manual stop during the delay period.""" """Test manual stop during the delay period."""
envs.VLLM_PROFILER_DELAY_ITERS = 5 default_profiler_config.delay_iterations = 5
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
profiler.step() profiler.step()

View File

@ -9,6 +9,8 @@ import vllm.envs as envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
_FP8_DTYPE = current_platform.fp8_dtype()
def is_aiter_found() -> bool: def is_aiter_found() -> bool:
from importlib.util import find_spec from importlib.util import find_spec
@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks. # we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found() IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might be because get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def if_aiter_supported(func: Callable) -> Callable: def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if """Decorator that only executes the function if
@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper return wrapper
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import dtypes
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_fused_moe_impl( def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
@ -467,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return torch.empty_like(x), torch.empty_like(residual) return torch.empty_like(x), torch.empty_like(residual)
def _rocm_aiter_per_tensor_quant_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-tensor quantization of ``x`` through aiter.

    Forwards ``(x, scale, quant_dtype)`` to
    ``aiter.ops.quant.per_tensor_quant_hip`` and returns its result unchanged.

    Args:
        x: Tensor to quantize.
        quant_dtype: Target dtype handed to the aiter kernel.
        scale: Optional scale tensor, forwarded as-is (``None`` when omitted).

    Returns:
        Whatever ``per_tensor_quant_hip`` returns; annotated as a
        ``(quantized, scale)`` pair.
    """
    # Deferred import: aiter is only required when the op actually runs.
    from aiter.ops.quant import per_tensor_quant_hip as quantize_per_tensor

    result = quantize_per_tensor(x, scale, quant_dtype)
    return result
def _rocm_aiter_per_tensor_quant_fake(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x, dtype=quant_dtype), torch.empty(
1, dtype=torch.float32, device=x.device
)
def _rocm_aiter_per_token_quant_impl(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import dynamic_per_token_scaled_quant
assert quant_dtype in [torch.int8, _FP8_DTYPE]
out_shape = x.shape
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
if scale is None:
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
dynamic_per_token_scaled_quant(
out,
x,
scale,
scale_ub=None,
shuffle_scale=False,
num_rows=None,
num_rows_factor=1,
)
return out, scale
def _rocm_aiter_per_token_quant_fake(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
out_shape = x.shape
return (
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """RMSNorm with residual input, fused with FP8 group quantization.

    Calls aiter's Triton ``fused_rms_fp8_group_quant`` with ``residual``
    passed as ``res1`` and ``AITER_FP8_DTYPE`` as the quantized dtype.

    Returns:
        ``(x_quant, x_quant_scales, res)`` taken from the first and last
        elements of the kernel's four-element result; the two middle elements
        are discarded.
    """
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    fused_out = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        # Three positional slots deliberately left empty — presumably the
        # optional second-input arguments; confirm against the aiter signature.
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
        res1=residual,
    )
    (quantized, scales), _, _, residual_out = fused_out
    return (quantized, scales, residual_out)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape/dtype-only stand-in for the fused add + RMSNorm + FP8 group quant op.

    Requires a 2-D ``x`` (its shape is unpacked into two values).

    Returns:
        Uninitialized ``(x_quant, scales, residual_out)``: ``x_quant`` is
        shaped like ``x`` with dtype ``AITER_FP8_DTYPE``; ``scales`` is float32
        of shape ``(M, ceil(N / group_size))``; ``residual_out`` mirrors
        ``residual``.
    """
    rows, cols = x.shape
    # Ceil-division: one scale per (possibly partial) group along the last dim.
    groups_per_row = (cols + group_size - 1) // group_size
    quant_placeholder = torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device)
    scale_placeholder = torch.empty(
        (rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    residual_placeholder = torch.empty_like(residual, device=residual.device)
    return (quant_placeholder, scale_placeholder, residual_placeholder)
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm fused with FP8 group quantization (no residual input).

    Calls aiter's Triton ``fused_rms_fp8_group_quant`` with ``res1=None`` and
    ``AITER_FP8_DTYPE`` as the quantized dtype.

    Returns:
        ``(x_quant, x_quant_scales)`` — the first element of the kernel's
        four-element result; the remaining three elements are discarded.
    """
    from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

    fused_out = fused_rms_fp8_group_quant(
        x,
        weight,
        variance_epsilon,
        # Three positional slots deliberately left empty — presumably the
        # optional second-input arguments; confirm against the aiter signature.
        None,
        None,
        None,
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
        res1=None,
    )
    (quantized, scales), _, _, _ = fused_out
    return (quantized, scales)
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape/dtype-only stand-in for the fused RMSNorm + FP8 group quant op.

    Requires a 2-D ``x`` (its shape is unpacked into two values).

    Returns:
        Uninitialized ``(x_quant, scales)``: ``x_quant`` is shaped like ``x``
        with dtype ``AITER_FP8_DTYPE``; ``scales`` is float32 of shape
        ``(M, ceil(N / group_size))``.
    """
    rows, cols = x.shape
    # Ceil-division: one scale per (possibly partial) group along the last dim.
    groups_per_row = (cols + group_size - 1) // group_size
    quant_placeholder = torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device)
    scale_placeholder = torch.empty(
        (rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    return (quant_placeholder, scale_placeholder)
def _rocm_aiter_group_fp8_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """FP8 group quantization of ``x`` using aiter's HIP quant kernel.

    ``group_size`` is only used to validate that the last dimension of ``x``
    divides evenly; the kernel requested from aiter is always
    ``QuantType.per_1x128``.

    Raises:
        AssertionError: If ``x.shape[-1]`` is not a multiple of ``group_size``.
    """
    last_dim = x.shape[-1]
    assert last_dim % group_size == 0, "Input shape must be divisible by group size"
    from aiter import QuantType, get_hip_quant

    quant_kernel = get_hip_quant(QuantType.per_1x128)
    # The kernel is given a contiguous copy/view of the input.
    contiguous_x = x.contiguous()
    return quant_kernel(contiguous_x, quant_dtype=AITER_FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape/dtype-only stand-in for the FP8 group quant op.

    Requires a 2-D ``x`` (its shape is unpacked into two values).

    Returns:
        Uninitialized ``(x_fp8, scales)``: ``x_fp8`` has shape ``(M, N)`` and
        dtype ``AITER_FP8_DTYPE``; ``scales`` is float32 of shape
        ``(M, ceil(N / group_size))``.
    """
    rows, cols = x.shape
    quant_placeholder = torch.empty((rows, cols), dtype=AITER_FP8_DTYPE, device=x.device)
    # Ceil-division: one scale per (possibly partial) group along the last dim.
    scale_shape = (rows, (cols + group_size - 1) // group_size)
    scale_placeholder = torch.empty(scale_shape, dtype=torch.float32, device=x.device)
    return quant_placeholder, scale_placeholder
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Activation-and-multiply fused with FP8 group quantization via aiter.

    Forwards to aiter's Triton ``act_mul_and_fp8_group_quant`` with the
    activation fixed to ``"silu"`` and ``AITER_FP8_DTYPE`` as the quantized
    dtype, and returns the kernel's result unchanged.
    """
    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

    kernel_kwargs = dict(
        activation="silu",
        group_size=group_size,
        dtype_quant=AITER_FP8_DTYPE,
    )
    return act_mul_and_fp8_group_quant(x, **kernel_kwargs)
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape/dtype-only stand-in for the fused act-mul + FP8 group quant op.

    Requires a 2-D ``x`` with an even last dimension; the outputs are sized
    for half of that dimension.

    Returns:
        Uninitialized ``(x_fp8, scales)``: ``x_fp8`` has shape ``(M, N // 2)``
        and dtype ``AITER_FP8_DTYPE``; ``scales`` is float32 of shape
        ``(M, ceil((N // 2) / group_size))``.

    Raises:
        AssertionError: If the last dimension of ``x`` is odd.
    """
    rows, cols = x.shape
    assert cols % 2 == 0
    half_cols = cols // 2
    quant_placeholder = torch.empty(
        (rows, half_cols), dtype=AITER_FP8_DTYPE, device=x.device
    )
    # Ceil-division: one scale per (possibly partial) group of the halved dim.
    scale_shape = (rows, (half_cols + group_size - 1) // group_size)
    scale_placeholder = torch.empty(scale_shape, dtype=torch.float32, device=x.device)
    return quant_placeholder, scale_placeholder
# Global flag to ensure ops are registered only once # Global flag to ensure ops are registered only once
_OPS_REGISTERED = False _OPS_REGISTERED = False
@ -502,7 +672,7 @@ class rocm_aiter_ops:
@if_aiter_supported @if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool: def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable.""" """ "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() return cls.is_linear_enabled()
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
@ -577,14 +747,6 @@ class rocm_aiter_ops:
) )
# register all the custom ops here # register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_group_fp8_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1", op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl, op_func=_rocm_aiter_asm_moe_tkw1_impl,
@ -644,27 +806,62 @@ class rocm_aiter_ops:
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8_blockscale", op_name="rocm_aiter_gemm_a8w8_blockscale",
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rms_norm", op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl, op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake, fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
fake_impl=_rocm_aiter_group_fp8_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_per_tensor_quant",
op_func=_rocm_aiter_per_tensor_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_per_tensor_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_per_token_quant",
op_func=_rocm_aiter_per_token_quant_impl,
mutates_args=["scale"],
fake_impl=_rocm_aiter_per_token_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True _OPS_REGISTERED = True
@staticmethod @staticmethod
@ -859,6 +1056,22 @@ class rocm_aiter_ops:
kv_scale=kv_scale, kv_scale=kv_scale,
) )
@staticmethod
def per_tensor_quant(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Invoke the ``rocm_aiter_per_tensor_quant`` op from ``torch.ops.vllm``.

    All three arguments are forwarded positionally and the op's result is
    returned unchanged.
    """
    quant_op = torch.ops.vllm.rocm_aiter_per_tensor_quant
    return quant_op(x, quant_dtype, scale)
@staticmethod
def per_token_quant(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Invoke the ``rocm_aiter_per_token_quant`` op from ``torch.ops.vllm``.

    All three arguments are forwarded positionally and the op's result is
    returned unchanged.
    """
    quant_op = torch.ops.vllm.rocm_aiter_per_token_quant
    return quant_op(x, quant_dtype, scale)
@staticmethod @staticmethod
def triton_fp4_gemm_dynamic_qaunt( def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor, x: torch.Tensor,

View File

@ -1726,7 +1726,7 @@ def scaled_fp8_quant(
output, input, scale, scale_ub output, input, scale, scale_ub
) )
else: else:
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32) scale = torch.empty(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else: else:
assert scale.numel() == 1, f"{scale.shape}" assert scale.numel() == 1, f"{scale.shape}"

View File

@ -89,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func try:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
else: else:
flash_attn_varlen_func = None flash_attn_varlen_func = None

View File

@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached # needed here to know how many tokens to attend to from the cached
# cross-attention KV cache. # cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata.seq_lens_cpu = torch.from_numpy( new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu common_attn_metadata.encoder_seq_lens_cpu
) )

View File

@ -12,7 +12,6 @@ from typing import Any
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
import vllm.envs as envs
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
raise OSError(
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
"Please set it to a valid path to use torch profiler."
)
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
if args.profile and not engine_args.profiler_config.profiler == "torch":
raise ValueError(
"The torch profiler is not enabled. Please provide profiler_config."
)
# Lazy import to avoid importing LLM when the bench command is not selected. # Lazy import to avoid importing LLM when the bench command is not selected.
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
@ -144,7 +142,7 @@ def main(args: argparse.Namespace):
run_to_completion(profile_dir=None) run_to_completion(profile_dir=None)
if args.profile: if args.profile:
profile_dir = envs.VLLM_TORCH_PROFILER_DIR profile_dir = engine_args.profiler_config.torch_profiler_dir
print(f"Profiling (results will be saved to '{profile_dir}')...") print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir) run_to_completion(profile_dir=profile_dir)
return return

View File

@ -1097,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
help="Use Torch Profiler. The endpoint must be launched with " help="Use vLLM Profiling. --profiler-config must be provided on the server.",
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
) )
parser.add_argument( parser.add_argument(
"--save-result", "--save-result",

View File

@ -655,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--profile", "--profile",
action="store_true", action="store_true",
default=False, default=False,
help="Use Torch Profiler. The env variable " help="Use vLLM Profiling. --profiler-config must be provided on the server.",
"VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
) )
# prefix repetition dataset # prefix repetition dataset

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import functools import functools
import hashlib import hashlib
import inspect import inspect
@ -8,15 +10,17 @@ import json
import types import types
from collections.abc import Callable from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import TYPE_CHECKING, Any
import torch import torch
from torch import fx from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.config.utils import Range
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING:
from vllm.config.utils import Range
if is_torch_equal_or_newer("2.6"): if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass from torch._inductor.custom_graph_pass import CustomGraphPass
else: else:

View File

@ -5,6 +5,7 @@ import functools
from torch import fx as fx from torch import fx as fx
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
from .post_cleanup import PostCleanupPass from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass from .fusion import RMSNormQuantFusionPass
@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.fuse_norm_quant: if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)] self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_quant: if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)] self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_attn_quant: if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)] self.passes += [AttnFusionPass(config)]

View File

@ -53,8 +53,27 @@ class PiecewiseBackend:
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1 self.is_full_graph = total_piecewise_compiles == 1
# TODO: we need to generalize encoder compilation to other models
self.is_encoder_compilation = vllm_backend.prefix in [
"Qwen2_5_VisionPatchEmbed",
"Qwen2_5_VisionPatchMerger",
"Qwen2_5_VisionBlock",
]
self.compile_ranges = self.compilation_config.get_compile_ranges() self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
# For encoder compilation we use the max int32 value
# to set the upper bound of the compile ranges
max_int32 = 2**31 - 1
last_compile_range = self.compile_ranges[-1]
assert (
last_compile_range.end
== vllm_config.scheduler_config.max_num_batched_tokens
)
self.compile_ranges[-1] = Range(
start=last_compile_range.start, end=max_int32
)
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string) logger.debug_once(log_string)

View File

@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fusion import empty_bf16
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
# FP8 dtype reported by the current platform; handed to the pattern classes below.
FP8_DTYPE = current_platform.fp8_dtype()
# Handles to vLLM custom-op overloads, resolved once at import time.
# NOTE(review): these attribute lookups presumably require the ops to be
# registered before this module is imported — confirm the import order.
# Fused ops used as replacement targets (rms_norm [+ residual] + group fp8 quant).
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
# Unfused rms_norm ops that the patterns search for.
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
# The two group-fp8 quant ops a pattern may be registered against.
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
# Fused silu-and-mul + group fp8 quant op used as a replacement target.
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
class AiterRMSFp8GroupQuantPattern:
    """
    Rewrites an aiter rms_norm followed by a group fp8 quant op into the
    single fused ``AITER_RMS_GROUP_QUANT_OP``. The group size is fixed to 128
    in both the searched pattern and the replacement.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
        self.epsilon = epsilon
        # Stored for reference only; register() does not read it.
        self.quant_dtype = quant_dtype
        self.quant_op = quant_op

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""
        eps = self.epsilon
        group_quant = self.quant_op

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            normed = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=eps)
            quantized = group_quant(normed, 128)
            return quantized[0], quantized[1]

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            fused = AITER_RMS_GROUP_QUANT_OP(
                x=input,
                weight=weight,
                variance_epsilon=eps,
                group_size=128,
            )
            return fused[0], fused[1]

        # Example tensors handed to register_replacement for tracing.
        example_inputs = [
            empty_bf16(5, 4),  # input
            empty_bf16(1, 5),  # weight
        ]
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
class AiterFusedAddRMSFp8GroupQuantPattern:
    """
    Rewrites an aiter rms_norm_with_add followed by a group fp8 quant op into
    the single fused ``AITER_RMS_ADD_GROUP_QUANT_OP``. The group size is fixed
    to 128 in both the searched pattern and the replacement.
    """

    def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
        self.epsilon = epsilon
        # Stored for reference only; register() does not read it.
        self.quant_dtype = quant_dtype
        self.quant_op = quant_op

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""
        eps = self.epsilon
        group_quant = self.quant_op

        def pattern(
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
        ):
            norm_out = AITER_RMS_ADD_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=eps,
            )
            quant_out = group_quant(norm_out[0], 128)

            # result, scale, residual
            return quant_out[0], quant_out[1], norm_out[1]

        def replacement(
            input: torch.Tensor,
            residual: torch.Tensor,
            weight: torch.Tensor,
        ):
            fused = AITER_RMS_ADD_GROUP_QUANT_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=eps,
                group_size=128,
            )

            # result, scale, residual
            return fused[0], fused[1], fused[2]

        # Example tensors handed to register_replacement for tracing.
        example_inputs = [
            empty_bf16(5, 4),  # input
            empty_bf16(5, 4),  # residual
            empty_bf16(1, 5),  # weight
        ]
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    Pass that fuses aiter rms_norm + group fp8 quant custom ops into a fused
    rms_norm_quant op. The residual-add variant (rms_norm_with_add) is
    handled as well.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
        )

        # Patterns are registered for every (epsilon, quant op) combination.
        # NOTE(review): within each combination the plain rms_norm pattern is
        # registered BEFORE the fused-add pattern. If the fused-add pattern
        # has to take precedence, this order must be swapped — confirm.
        supported_quant_ops = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
        for epsilon in [1e-5, 1e-6]:
            for quant_op in supported_quant_ops:
                # rms_norm + dynamic group fp8 quant
                plain_pattern = AiterRMSFp8GroupQuantPattern(
                    epsilon, FP8_DTYPE, quant_op
                )
                plain_pattern.register(self.patterns)

                # rms_norm_with_add + dynamic group fp8 quant
                fused_add_pattern = AiterFusedAddRMSFp8GroupQuantPattern(
                    epsilon, FP8_DTYPE, quant_op
                )
                fused_add_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        """Hash of this pass's source together with its pattern classes."""
        return self.hash_source(
            self,
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        )
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    Rewrites silu_and_mul followed by a group fp8 quant op into the single
    fused ``FUSED_SILU_MUL_QUANT_OP``. The group size is fixed to 128 in both
    the searched pattern and the replacement.
    """

    def __init__(self, quant_op: OpOverload):
        # The base-class __init__ is intentionally not invoked here, matching
        # the existing behavior; only these two attributes are set.
        self.silu_and_mul_matcher = MatcherSiluAndMul()
        self.quant_op = quant_op

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search pattern and its replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
        ):
            activated = self.silu_and_mul_matcher(input)
            quantized = self.quant_op(activated, 128)
            return quantized[0], quantized[1]

        def replacement(
            input: torch.Tensor,
        ):
            fused = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
            return fused[0], fused[1]

        # Only the matcher's first example input is needed by the pattern.
        first_example = self.silu_and_mul_matcher.inputs()[0]
        example_inputs = [first_example]
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    Pass that fuses silu_and_mul + group fp8 quant into one fused custom op,
    using the torch pattern matcher to find and replace the patterns.

    One pattern is registered per supported group-quant op (aiter and triton).
    Background on pattern registration limits in PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )

        supported_quant_ops = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
        for quant_op in supported_quant_ops:
            silu_mul_pattern = AiterSiluMulFp8GroupQuantPattern(quant_op)
            silu_mul_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph):
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self):
        """Hash of this pass's source together with its pattern classes."""
        return VllmInductorPass.hash_source(
            self,
            ActivationQuantPattern,
            AiterSiluMulFp8GroupQuantPattern,
        )

View File

@ -24,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig from vllm.config.observability import ObservabilityConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
from vllm.config.scheduler import SchedulerConfig from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.speech_to_text import SpeechToTextConfig
@ -89,6 +90,8 @@ __all__ = [
"SpeechToTextConfig", "SpeechToTextConfig",
# From vllm.config.structured_outputs # From vllm.config.structured_outputs
"StructuredOutputsConfig", "StructuredOutputsConfig",
# From vllm.config.profiler
"ProfilerConfig",
# From vllm.config.utils # From vllm.config.utils
"ConfigType", "ConfigType",
"SupportsMetricsInfo", "SupportsMetricsInfo",

199
vllm/config/profiler.py Normal file
View File

@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Literal
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
@config
@dataclass
class ProfilerConfig:
"""Dataclass which contains profiler config for the engine."""
profiler: ProfilerKind | None = None
"""Which profiler to use. Defaults to None. Options are:
- 'torch': Use PyTorch profiler.\n
- 'cuda': Use CUDA profiler."""
torch_profiler_dir: str = ""
"""Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
worker's traces (CPU & GPU) will be saved under this directory. Note that
it must be an absolute path."""
torch_profiler_with_stack: bool = True
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
torch_profiler_with_flops: bool = False
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
torch_profiler_use_gzip: bool = True
"""If `True`, saves torch profiler traces in gzip format. Enabled by default"""
torch_profiler_dump_cuda_time_total: bool = True
"""If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""
torch_profiler_record_shapes: bool = False
"""If `True`, records tensor shapes in the torch profiler. Disabled by default."""
torch_profiler_with_memory: bool = False
"""If `True`, enables memory profiling in the torch profiler.
Disabled by default."""
ignore_frontend: bool = False
"""If `True`, disables the front-end profiling of AsyncLLM when using the
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
since the front-end profiling does not track iterations and will capture the
entire range.
"""
delay_iterations: int = Field(default=0, ge=0)
"""Number of engine iterations to skip before starting profiling.
Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
"""
max_iterations: int = Field(default=0, ge=0)
"""Maximum number of engine iterations to profile after starting profiling.
Defaults to 0, meaning no limit.
"""
def compute_hash(self) -> str:
    """
    WARNING: Whenever a new field is added to this config,
    ensure that it is included in the factors list if
    it affects the computation graph.

    Return a hash that uniquely identifies every setting of this config
    that affects the structure of the computation graph from input
    ids/embeddings to the final hidden states (nothing before the input
    ids/embeddings or after the final hidden states is considered).
    """
    # No profiler setting influences the computation graph, so the list of
    # hashed factors is intentionally empty.
    graph_factors: list[Any] = []
    encoded_factors = str(graph_factors).encode()
    return safe_hash(encoded_factors, usedforsecurity=False).hexdigest()
def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> str | None:
    """Read a deprecated environment variable, warning once if it is set.

    Args:
        field_name: Name of the config field that replaces the variable;
            only used to build the deprecation message.
        env_var_name: Name of the attribute to look up on `envs`.

    Returns:
        The value exposed by `envs` for `env_var_name` when
        `envs.is_set(env_var_name)` is true, otherwise `None`. The profiling
        variables are declared as `str | None` in `envs`, so callers receive
        the raw, unparsed value.
    """
    if not envs.is_set(env_var_name):
        return None

    # Read the value before logging so the lookup order matches the
    # original implementation.
    value = getattr(envs, env_var_name)
    logger.warning_once(
        "Using %s environment variable is deprecated and will be removed in "
        "v0.14.0 or v1.0.0, whichever is soonest. Please use "
        "--profiler-config.%s command line argument or "
        "ProfilerConfig(%s=...) config field instead.",
        env_var_name,
        field_name,
        field_name,
    )
    return value
def _set_from_env_if_set(
self,
field_name: str,
env_var_name: str,
to_bool: bool = True,
to_int: bool = False,
) -> None:
"""Set field from env var if set, with deprecation warning."""
value = self._get_from_env_if_set(field_name, env_var_name)
if value is not None:
if to_bool:
value = value == "1"
if to_int:
value = int(value)
setattr(self, field_name, value)
# "After" validator (presumably pydantic's `model_validator`): folds the
# deprecated profiling environment variables into this config, then checks
# that `profiler` and `torch_profiler_dir` agree and normalizes the
# directory path. Returns `self`.
# NOTE(review): leading indentation is not preserved in this view, so the
# nesting of the statements below must be confirmed against the real file.
@model_validator(mode="after")
def _validate_profiler_config(self) -> Self:
# VLLM_TORCH_CUDA_PROFILE is consulted first. When it is set, the value "1"
# selects the "cuda" profiler and any other value sets `profiler` to None.
maybe_use_cuda_profiler = self._get_from_env_if_set(
"profiler", "VLLM_TORCH_CUDA_PROFILE"
)
if maybe_use_cuda_profiler is not None:
self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
else:
# VLLM_TORCH_PROFILER_DIR is copied as a raw string (no bool conversion).
# A non-empty `torch_profiler_dir` switches `profiler` to "torch".
# NOTE(review): whether the `if self.torch_profiler_dir` check sits inside
# this else branch cannot be told from this view — confirm.
self._set_from_env_if_set(
"torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
)
if self.torch_profiler_dir:
self.profiler = "torch"
# Boolean flags: each deprecated variable, when set, overrides its field
# using the helper's default bool conversion (True only for "1").
self._set_from_env_if_set(
"torch_profiler_record_shapes",
"VLLM_TORCH_PROFILER_RECORD_SHAPES",
)
self._set_from_env_if_set(
"torch_profiler_with_memory",
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
)
self._set_from_env_if_set(
"torch_profiler_with_stack",
"VLLM_TORCH_PROFILER_WITH_STACK",
)
self._set_from_env_if_set(
"torch_profiler_with_flops",
"VLLM_TORCH_PROFILER_WITH_FLOPS",
)
self._set_from_env_if_set(
"ignore_frontend",
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
)
self._set_from_env_if_set(
"torch_profiler_use_gzip",
"VLLM_TORCH_PROFILER_USE_GZIP",
)
self._set_from_env_if_set(
"torch_profiler_dump_cuda_time_total",
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
)
# Iteration counters are parsed with int() instead of the bool conversion.
# NOTE(review): values assigned here via setattr presumably bypass the
# `ge=0` constraint declared on the fields — confirm.
self._set_from_env_if_set(
"delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
)
self._set_from_env_if_set(
"max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
)
# Warn (once) when the torch profiler is combined with a delay or an
# iteration limit while front-end profiling is still enabled.
has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
logger.warning_once(
"Using 'torch' profiler with delay_iterations or max_iterations "
"while ignore_frontend is False may result in high overhead."
)
# `profiler` and `torch_profiler_dir` must be consistent: a directory
# without the torch profiler, or the torch profiler without a directory,
# is rejected with ValueError.
profiler_dir = self.torch_profiler_dir
if profiler_dir and self.profiler != "torch":
raise ValueError(
"torch_profiler_dir is only applicable when profiler is set to 'torch'"
)
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
# A value is treated as a gs:// path only if something follows the
# "gs://" prefix and that next character is not "/". Such values are kept
# verbatim; every other directory is ~-expanded and made absolute.
if profiler_dir:
is_gs_path = (
profiler_dir.startswith("gs://")
and profiler_dir[5:]
and profiler_dir[5] != "/"
)
if not is_gs_path:
self.torch_profiler_dir = os.path.abspath(
os.path.expanduser(profiler_dir)
)
return self

View File

@ -39,6 +39,7 @@ from .lora import LoRAConfig
from .model import ModelConfig from .model import ModelConfig
from .observability import ObservabilityConfig from .observability import ObservabilityConfig
from .parallel import ParallelConfig from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig from .structured_outputs import StructuredOutputsConfig
@ -218,6 +219,8 @@ class VllmConfig:
You can specify the full compilation config like so: You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
""" """
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
"""Profiling configuration."""
kv_transfer_config: KVTransferConfig | None = None kv_transfer_config: KVTransferConfig | None = None
"""The configurations for distributed KV cache transfer.""" """The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None kv_events_config: KVEventsConfig | None = None
@ -296,6 +299,8 @@ class VllmConfig:
vllm_factors.append("None") vllm_factors.append("None")
if self.structured_outputs_config: if self.structured_outputs_config:
vllm_factors.append(self.structured_outputs_config.compute_hash()) vllm_factors.append(self.structured_outputs_config.compute_hash())
if self.profiler_config:
vllm_factors.append(self.profiler_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
vllm_factors.append(self.observability_config.compute_hash()) vllm_factors.append(self.observability_config.compute_hash())
@ -1042,8 +1047,14 @@ class VllmConfig:
self.compilation_config.max_cudagraph_capture_size self.compilation_config.max_cudagraph_capture_size
) )
if max_cudagraph_capture_size is None: if max_cudagraph_capture_size is None:
decode_query_len = 1
if (
self.speculative_config
and self.speculative_config.num_speculative_tokens
):
decode_query_len += self.speculative_config.num_speculative_tokens
max_cudagraph_capture_size = min( max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512 self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
) )
max_num_tokens = self.scheduler_config.max_num_batched_tokens max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size) max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

View File

@ -50,6 +50,7 @@ from vllm.config import (
ObservabilityConfig, ObservabilityConfig,
ParallelConfig, ParallelConfig,
PoolerConfig, PoolerConfig,
ProfilerConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig, SpeculativeConfig,
StructuredOutputsConfig, StructuredOutputsConfig,
@ -532,6 +533,8 @@ class EngineArgs:
worker_cls: str = ParallelConfig.worker_cls worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls
profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")
kv_transfer_config: KVTransferConfig | None = None kv_transfer_config: KVTransferConfig | None = None
kv_events_config: KVEventsConfig | None = None kv_events_config: KVEventsConfig | None = None
@ -1164,7 +1167,7 @@ class EngineArgs:
vllm_group.add_argument( vllm_group.add_argument(
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
) )
vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
vllm_group.add_argument( vllm_group.add_argument(
"--optimization-level", **vllm_kwargs["optimization_level"] "--optimization-level", **vllm_kwargs["optimization_level"]
) )
@ -1782,6 +1785,7 @@ class EngineArgs:
kv_transfer_config=self.kv_transfer_config, kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config, kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config, ec_transfer_config=self.ec_transfer_config,
profiler_config=self.profiler_config,
additional_config=self.additional_config, additional_config=self.additional_config,
optimization_level=self.optimization_level, optimization_level=self.optimization_level,
) )

View File

@ -8,3 +8,5 @@ Shared constants for vLLM entrypoints.
# These constants help mitigate header abuse attacks # These constants help mitigate header abuse attacks
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
H11_MAX_HEADER_COUNT_DEFAULT = 256 H11_MAX_HEADER_COUNT_DEFAULT = 256
MCP_PREFIX = "mcp_"

View File

@ -19,6 +19,7 @@ from vllm import envs
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
) )
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.parser.harmony_utils import ( from vllm.entrypoints.openai.parser.harmony_utils import (
get_encoding, get_encoding,
get_streamable_parser_for_assistant, get_streamable_parser_for_assistant,
@ -303,7 +304,7 @@ class ParsableContext(ConversationContext):
result_str = result.content[0].text result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem( message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}", id=f"mcpo_{random_uuid()}",
type="function_call_output", type="function_call_output",
call_id=f"call_{random_uuid()}", call_id=f"call_{random_uuid()}",
output=result_str, output=result_str,
@ -385,6 +386,9 @@ class ParsableContext(ConversationContext):
if not self.parser.response_messages: if not self.parser.response_messages:
return [] return []
last_msg = self.parser.response_messages[-1] last_msg = self.parser.response_messages[-1]
# change this to a mcp_ function call
last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
self.parser.response_messages[-1] = last_msg
if last_msg.name == "code_interpreter": if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg) return await self.call_python_tool(self._tool_sessions["python"], last_msg)
elif last_msg.name == "web_search_preview": elif last_msg.name == "web_search_preview":

View File

@ -20,6 +20,7 @@ from vllm.beam_search import (
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
PoolerConfig, PoolerConfig,
ProfilerConfig,
StructuredOutputsConfig, StructuredOutputsConfig,
is_init_field, is_init_field,
) )
@ -211,6 +212,7 @@ class LLM:
structured_outputs_config: dict[str, Any] structured_outputs_config: dict[str, Any]
| StructuredOutputsConfig | StructuredOutputsConfig
| None = None, | None = None,
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
kv_cache_memory_bytes: int | None = None, kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None, compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None, logits_processors: list[str | type[LogitsProcessor]] | None = None,
@ -282,6 +284,20 @@ class LLM:
else: else:
structured_outputs_instance = StructuredOutputsConfig() structured_outputs_instance = StructuredOutputsConfig()
if profiler_config is not None:
if isinstance(profiler_config, dict):
profiler_config_instance = ProfilerConfig(
**{
k: v
for k, v in profiler_config.items()
if is_init_field(ProfilerConfig, k)
}
)
else:
profiler_config_instance = profiler_config
else:
profiler_config_instance = ProfilerConfig()
# warn about single-process data parallel usage. # warn about single-process data parallel usage.
_dp_size = int(kwargs.get("data_parallel_size", 1)) _dp_size = int(kwargs.get("data_parallel_size", 1))
_distributed_executor_backend = kwargs.get("distributed_executor_backend") _distributed_executor_backend = kwargs.get("distributed_executor_backend")
@ -324,6 +340,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config, pooler_config=pooler_config,
structured_outputs_config=structured_outputs_instance, structured_outputs_config=structured_outputs_instance,
profiler_config=profiler_config_instance,
compilation_config=compilation_config_instance, compilation_config=compilation_config_instance,
logits_processors=logits_processors, logits_processors=logits_processors,
**kwargs, **kwargs,

View File

@ -1339,6 +1339,7 @@ class OpenAIServing:
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0] request_prompt = request_prompts[0]
prompt_text, _, _ = self._get_prompt_components(request_prompt)
# Update the sampling params. # Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len( sampling_params.max_tokens = self.max_model_len - len(

View File

@ -99,12 +99,7 @@ class MistralToolParser(ToolParser):
self.bot_token = "[TOOL_CALLS]" self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token) self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if not _is_pre_v11_tokeniser(self.model_tokenizer): self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
)
else:
self.fn_name_regex = None
if self.bot_token_id is None: if self.bot_token_id is None:
raise RuntimeError( raise RuntimeError(
@ -148,23 +143,24 @@ class MistralToolParser(ToolParser):
tool_content = model_output.replace(self.bot_token, "").strip() tool_content = model_output.replace(self.bot_token, "").strip()
try: try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try: try:
if self.fn_name_regex: if not self._is_pre_v11:
function_call_arr = [] function_call_arr = []
for single_tool_content in model_output.split(self.bot_token): for single_tool_content in model_output.split(self.bot_token):
matches = self.fn_name_regex.findall(single_tool_content) if "{" not in single_tool_content:
continue
for match in matches: end_name = single_tool_content.find("{")
fn_name = match[0] fn_name, args = (
args = match[1] single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump # fn_name is encoded outside serialized json dump
# only arguments are serialized # only arguments are serialized
function_call_arr.append( function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)} {"name": fn_name, "arguments": json.loads(args)}
) )
else: else:
function_call_arr = json.loads(tool_content) function_call_arr = json.loads(tool_content)
except json.JSONDecodeError: except json.JSONDecodeError:

View File

@ -22,6 +22,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool from openai.types.responses.tool import Tool
from vllm import envs from vllm import envs
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ResponseInputOutputItem, ResponseInputOutputItem,
@ -44,13 +45,13 @@ def make_response_output_items_from_parsable_context(
) )
if isinstance(output_messages[-1], ResponseFunctionToolCall): if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall( mcp_message = McpCall(
id=f"mcp_{random_uuid()}", id=f"{MCP_PREFIX}{random_uuid()}",
arguments=output_messages[-1].arguments, arguments=output_messages[-1].arguments,
name=output_messages[-1].name, name=output_messages[-1].name,
server_label=output_messages[ server_label=output_messages[
-1 -1
].name, # TODO: store the server label ].name, # TODO: store the server label
type="mcp_call", type=f"{MCP_PREFIX}call",
status="completed", status="completed",
output=message.output, output=message.output,
# TODO: support error output # TODO: support error output
@ -98,12 +99,63 @@ def construct_input_messages(
if isinstance(request_input, str): if isinstance(request_input, str):
messages.append({"role": "user", "content": request_input}) messages.append({"role": "user", "content": request_input})
else: else:
for item in request_input: input_messages = construct_chat_messages_with_tool_call(request_input)
messages.append(construct_chat_message_with_tool_call(item)) messages.extend(input_messages)
return messages return messages
def _maybe_combine_reasoning_and_tool_call(
    item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
    """Attach an MCP tool call to a preceding assistant reasoning message.

    Many models treat MCP calls and reasoning as a single message. When
    `item` is a function tool call whose id carries the MCP prefix and the
    last entry of `messages` is an assistant message with a non-None
    "reasoning" value, the tool call is written into that message.

    Args:
        item: The response item currently being converted.
        messages: Chat messages built so far; the last entry may be
            modified in place.

    Returns:
        The (mutated) last message when the merge applies, otherwise
        `None`, in which case `messages` is left untouched.
    """
    # The id is checked for truthiness first: a tool call without an id
    # (the field looks optional in the OpenAI SDK type — confirm) must not
    # crash on `.startswith` and simply is not an MCP call.
    if not (
        isinstance(item, ResponseFunctionToolCall)
        and item.id
        and item.id.startswith(MCP_PREFIX)
    ):
        return None
    if len(messages) == 0:
        return None

    last_message = messages[-1]
    if not (
        last_message.get("role") == "assistant"
        and last_message.get("reasoning") is not None
    ):
        return None

    # Overwrites any existing "tool_calls" entry with this single call.
    last_message["tool_calls"] = [
        ChatCompletionMessageToolCallParam(
            id=item.call_id,
            function=FunctionCallTool(
                name=item.name,
                arguments=item.arguments,
            ),
            type="function",
        )
    ]
    return last_message
def construct_chat_messages_with_tool_call(
    input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
    """Convert response items into chat messages, merging where needed.

    A single chat message can originate from several response items: a
    reasoning item followed by an MCP tool call is two response items but
    one chat message. Each item is therefore first offered to
    `_maybe_combine_reasoning_and_tool_call`; only when no merge happens is
    it converted on its own by
    `_construct_single_message_from_response_item`.
    """
    chat_messages: list[ChatCompletionMessageParam] = []
    for response_item in input_messages:
        merged = _maybe_combine_reasoning_and_tool_call(response_item, chat_messages)
        if merged is None:
            # No merge: the item becomes its own chat message.
            chat_messages.append(
                _construct_single_message_from_response_item(response_item)
            )
            continue
        # Merge happened: the combined message replaces the last entry.
        chat_messages[-1] = merged
    return chat_messages
def _construct_single_message_from_response_item(
item: ResponseInputOutputItem, item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam: ) -> ChatCompletionMessageParam:
if isinstance(item, ResponseFunctionToolCall): if isinstance(item, ResponseFunctionToolCall):

View File

@ -5,7 +5,7 @@
from fastapi import APIRouter, FastAPI, Request from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response from fastapi.responses import Response
import vllm.envs as envs from vllm.config import ProfilerConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger from vllm.logger import init_logger
@ -35,15 +35,12 @@ async def stop_profile(raw_request: Request):
def attach_router(app: FastAPI): def attach_router(app: FastAPI):
if envs.VLLM_TORCH_PROFILER_DIR: profiler_config = getattr(app.state.args, "profiler_config", None)
assert profiler_config is None or isinstance(profiler_config, ProfilerConfig)
if profiler_config is not None and profiler_config.profiler is not None:
logger.warning_once( logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be " "Profiler with mode '%s' is enabled in the "
"used for local development!" "API server. This should ONLY be used for local development!",
profiler_config.profiler,
) )
elif envs.VLLM_TORCH_CUDA_PROFILE:
logger.warning_once(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
app.include_router(router) app.include_router(router)

View File

@ -75,6 +75,7 @@ if TYPE_CHECKING:
VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9" VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
MAX_JOBS: str | None = None MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
@ -88,20 +89,23 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_TORCH_CUDA_PROFILE: bool = False # Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None
VLLM_TORCH_PROFILER_DIR: str | None = None VLLM_TORCH_PROFILER_DIR: str | None = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None
VLLM_TORCH_PROFILER_WITH_STACK: str | None = None
VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None
VLLM_TORCH_PROFILER_USE_GZIP: str | None = None
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None
VLLM_PROFILER_DELAY_ITERS: str | None = None
VLLM_PROFILER_MAX_ITERS: str | None = None
# End of deprecated env variables for profiling
VLLM_USE_AOT_COMPILE: bool = False VLLM_USE_AOT_COMPILE: bool = False
VLLM_USE_BYTECODE_HOOK: bool = False VLLM_USE_BYTECODE_HOOK: bool = False
VLLM_FORCE_AOT_LOAD: bool = False VLLM_FORCE_AOT_LOAD: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_PROFILER_DELAY_ITERS: int = 0
VLLM_PROFILER_MAX_ITERS: int = 0
VLLM_TORCH_PROFILER_USE_GZIP: bool = True
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
@ -452,6 +456,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Main CUDA version of vLLM. This follows PyTorch but can be overridden. # Main CUDA version of vLLM. This follows PyTorch but can be overridden.
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
or "12.9", or "12.9",
# Controls PyTorch float32 matmul precision mode within vLLM workers.
# Valid options mirror torch.set_float32_matmul_precision
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
"VLLM_FLOAT32_MATMUL_PRECISION",
"highest",
["highest", "high", "medium"],
case_sensitive=False,
),
# Maximum number of compilation jobs to run in parallel. # Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs # By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
@ -841,71 +853,52 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None "VLLM_LORA_RESOLVER_CACHE_DIR", None
), ),
# Enables torch CUDA profiling if set. # Enables torch CUDA profiling if set to 1.
# On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered. # Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: bool( "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
),
# Enables torch profiler if set. # Enables torch profiler if set.
# Both AsyncLLM's CPU traces as well as workers' # Deprecated, see profiler_config.
# traces (CPU & GPU) will be saved under this directory. "VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"),
# Note that it must be an absolute path. # Enable torch profiler to record shapes if set to 1.
"VLLM_TORCH_PROFILER_DIR": lambda: ( # Deprecated, see profiler_config.
None "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: (
if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES")
else (
val
if val.startswith("gs://") and val[5:] and val[5] != "/"
else os.path.abspath(os.path.expanduser(val))
)
), ),
# Enable torch profiler to record shapes if set # Enable torch profiler to profile memory if set to 1.
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will # Deprecated, see profiler_config.
# not record shapes. "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: (
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool( os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY")
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"
), ),
# Enable torch profiler to profile memory if set # Enable torch profiler to profile stack if set to 1.
# VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler # Deprecated, see profiler_config.
# will not profile memory. "VLLM_TORCH_PROFILER_WITH_STACK": lambda: (
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool( os.getenv("VLLM_TORCH_PROFILER_WITH_STACK")
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"
), ),
# Enable torch profiler to profile stack if set # Enable torch profiler to profile flops if set to 1.
# VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL # Deprecated, see profiler_config.
# profile stack by default. "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: (
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool( os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS")
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"
), ),
# Enable torch profiler to profile flops if set # Disable torch profiling of the AsyncLLMEngine process if set to 1.
# VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will # Deprecated, see profiler_config.
# not profile flops. "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: (
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool( os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM")
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
),
# Disable torch profiling of the AsyncLLMEngine process.
# If set to 1, will not profile the engine process.
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
), ),
# Delay number of iterations before starting profiling when using # Delay number of iterations before starting profiling when using
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately. # the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
"VLLM_PROFILER_DELAY_ITERS": lambda: int( # Deprecated, see profiler_config.
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0") "VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")),
),
# Maximum number of iterations to profile when using the torch/torch CUDA profiler. # Maximum number of iterations to profile when using the torch/torch CUDA profiler.
# If set to 0, will not limit the number of iterations. # If set to 0, will not limit the number of iterations.
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")), "VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"),
# Control whether torch profiler gzip-compresses profiling files. # Control whether torch profiler gzip-compresses profiling files.
# Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default). # Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool( "VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"),
os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0"
),
# Control whether torch profiler dumps the self_cuda_time_total table. # Control whether torch profiler dumps the self_cuda_time_total table.
# Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping # Set to 0 to disable dumping the table.
# (enabled by default). # Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool( "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: (
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0" os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL")
), ),
# If set, vLLM will use Triton implementations of AWQ. # If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),

View File

@ -292,7 +292,7 @@ def set_forward_context(
if num_tokens_across_dp is None: if num_tokens_across_dp is None:
assert ubatch_slices is None assert ubatch_slices is None
assert num_tokens is not None assert num_tokens is not None
_, num_tokens_across_dp = coordinate_batch_across_dp( _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config, parallel_config=vllm_config.parallel_config,
allow_microbatching=False, allow_microbatching=False,

View File

@ -935,7 +935,11 @@ def enable_batch_invariant_mode():
# Batch invariant matmuls are no longer needed after cublas overrides # Batch invariant matmuls are no longer needed after cublas overrides
if not is_torch_equal_or_newer("2.10.0.dev"): if not is_torch_equal_or_newer("2.10.0.dev"):
if current_platform.is_device_capability(100): if (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(80)
or current_platform.is_device_capability(89)
):
# For PyTorch 2.9, B200 uses GEMV for bs=1 # For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735 # Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")

View File

@ -4,7 +4,10 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import Any
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
@ -49,6 +52,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute", "FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat", "FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize", "FusedMoEPrepareAndFinalize",
"RoutingMethodType",
"SharedFusedMoE", "SharedFusedMoE",
"activation_without_mul", "activation_without_mul",
"override_config", "override_config",

View File

@ -895,6 +895,48 @@ def get_moe_configs(
return None return None
def _ensure_block_size_k_divisible(
size_k: int, block_size_k: int, group_size: int
) -> int:
"""Ensure block_size_k is a divisor of size_k and divisible by group_size.
This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which
requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0.
Args:
size_k: The size_k dimension that must be divisible by result.
block_size_k: Preferred block size (will be adjusted if needed).
group_size: The result must be divisible by this.
Returns:
A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size.
"""
# Fast path: already valid
if size_k % block_size_k == 0 and block_size_k % group_size == 0:
return block_size_k
# Find the largest value that:
# 1. Divides size_k (size_k % candidate == 0)
# 2. Is divisible by group_size (candidate % group_size == 0)
# 3. Is <= block_size_k (prefer smaller values close to block_size_k)
#
# Strategy: Search from min(block_size_k, size_k) down to group_size,
# stepping by group_size to ensure divisibility by group_size
max_search = min(block_size_k, size_k)
start = (max_search // group_size) * group_size
for candidate in range(start, group_size - 1, -group_size):
if size_k % candidate == 0:
return candidate
# Fallback: if group_size divides size_k, use it
# This should always be true with correct group_size configuration
if size_k % group_size == 0:
return group_size
# This should not happen with correct group_size, but ensure divisibility
return size_k
def get_moe_wna16_block_config( def get_moe_wna16_block_config(
config: dict[str, int], config: dict[str, int],
use_moe_wna16_cuda: bool, use_moe_wna16_cuda: bool,
@ -960,6 +1002,9 @@ def get_moe_wna16_block_config(
# at the same time. # at the same time.
block_size_n = 1024 block_size_n = 1024
# Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable
import torch import torch
@ -100,22 +99,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
@ -97,23 +96,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -127,10 +109,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=self.allow_inplace, inplace=self.allow_inplace,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
) )
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:

View File

@ -33,10 +33,6 @@ from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data, init_aiter_topK_meta_data,
) )
@ -57,11 +53,8 @@ from vllm.utils.torch_utils import (
from vllm.v1.worker.ubatching import dbo_current_ubatch_id from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_moe import eplb_map_to_physical_and_record, fused_experts from .fused_moe import eplb_map_to_physical_and_record
else: else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = object # type: ignore
FusedMoEPrepareAndFinalize = object # type: ignore
def _eplb_map_to_physical_and_record( def _eplb_map_to_physical_and_record(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
@ -483,7 +476,7 @@ class FusedMoE(CustomOp):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
) )
self.expert_map: torch.Tensor | None self._expert_map: torch.Tensor | None
local_num_experts, expert_map, expert_mask = determine_expert_map( local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
@ -493,7 +486,7 @@ class FusedMoE(CustomOp):
return_expert_mask=self.rocm_aiter_fmoe_enabled, return_expert_mask=self.rocm_aiter_fmoe_enabled,
) )
self.local_num_experts = local_num_experts self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map) self.register_buffer("_expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask) self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables() self._maybe_init_expert_routing_tables()
logger.info_once( logger.info_once(
@ -506,10 +499,10 @@ class FusedMoE(CustomOp):
self.expert_placement_strategy, self.expert_placement_strategy,
self.local_num_experts, self.local_num_experts,
self.global_num_experts, self.global_num_experts,
get_compressed_expert_map(self.expert_map), get_compressed_expert_map(self._expert_map),
) )
else: else:
self.local_num_experts, self.expert_map, self.expert_mask = ( self.local_num_experts, self._expert_map, self.expert_mask = (
self.global_num_experts, self.global_num_experts,
None, None,
None, None,
@ -781,7 +774,7 @@ class FusedMoE(CustomOp):
), ),
) )
if self.expert_map is None: if self._expert_map is None:
return None return None
routing_tables = self.ensure_round_robin_expert_routing_tables( routing_tables = self.ensure_round_robin_expert_routing_tables(
@ -789,7 +782,7 @@ class FusedMoE(CustomOp):
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
local_num_experts=self.local_num_experts, local_num_experts=self.local_num_experts,
device=self.expert_map.device, device=self._expert_map.device,
) )
global_to_physical, physical_to_global, local_global = routing_tables global_to_physical, physical_to_global, local_global = routing_tables
@ -840,8 +833,8 @@ class FusedMoE(CustomOp):
def update_expert_map(self): def update_expert_map(self):
# ep_size and ep_rank should already be updated # ep_size and ep_rank should already be updated
assert self.expert_map is not None assert self._expert_map is not None
with self.expert_map.device: with self._expert_map.device:
local_num_experts, expert_map, expert_mask = determine_expert_map( local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
@ -851,7 +844,7 @@ class FusedMoE(CustomOp):
return_expert_mask=self.rocm_aiter_fmoe_enabled, return_expert_mask=self.rocm_aiter_fmoe_enabled,
) )
self.local_num_experts = local_num_experts self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map) self.register_buffer("_expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask) self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables() self._maybe_init_expert_routing_tables()
if self.aiter_fmoe_shared_expert_enabled: if self.aiter_fmoe_shared_expert_enabled:
@ -888,7 +881,7 @@ class FusedMoE(CustomOp):
# Record that the clone will be used by shared_experts_stream # Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone # to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream()) # NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output. # because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream) hidden_states_clone.record_stream(self.shared_experts_stream)
@ -1068,9 +1061,9 @@ class FusedMoE(CustomOp):
expert_data.copy_(loaded_weight) expert_data.copy_(loaded_weight)
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None: if self._expert_map is None:
return expert_id return expert_id
return self.expert_map[expert_id].item() return self._expert_map[expert_id].item()
def _init_aiter_shared_experts_topK_buffer( def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int self, vllm_config: VllmConfig, dp_size: int
@ -1744,6 +1737,12 @@ class FusedMoE(CustomOp):
reduce_output(fused_output)[..., :og_hidden_states], reduce_output(fused_output)[..., :og_hidden_states],
) )
@property
def expert_map(self) -> torch.Tensor | None:
return (
self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
)
def forward_cuda( def forward_cuda(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -1805,24 +1804,6 @@ class FusedMoE(CustomOp):
layer=self, layer=self,
x=staged_hidden_states, x=staged_hidden_states,
router_logits=staged_router_logits, router_logits=staged_router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
) )
if has_separate_shared_experts: if has_separate_shared_experts:
@ -1968,25 +1949,6 @@ class FusedMoE(CustomOp):
if do_naive_dispatch_combine if do_naive_dispatch_combine
else hidden_states, else hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
) )
if has_separate_shared_experts: if has_separate_shared_experts:

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -269,53 +268,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
return self.forward( return self.forward(
x=x,
layer=layer, layer=layer,
x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
enable_eplb=enable_eplb,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
@ -333,24 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts( topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -364,9 +307,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
expert_map=expert_map, expert_map=layer.expert_map,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
elif self.flashinfer_cutlass_moe_enabled: elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe( return self.flashinfer_cutlass_moe(
@ -375,8 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
result = fused_experts( result = fused_experts(
@ -386,11 +329,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
) )
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
@ -405,148 +348,101 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
enable_eplb is not False layer.enable_eplb is not False
or expert_load_view is not None or layer.expert_load_view is not None
or logical_to_physical_map is not None or layer.logical_to_physical_map is not None
or logical_replica_count is not None or layer.logical_replica_count is not None
): ):
raise NotImplementedError("Expert load balancing is not supported for CPU.") raise NotImplementedError("Expert load balancing is not supported for CPU.")
return layer.cpu_fused_moe( return layer.cpu_fused_moe(
layer, layer,
x, x,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
global_num_experts, layer.global_num_experts,
expert_map, layer.expert_map,
custom_routing_function, layer.custom_routing_function,
scoring_func, layer.scoring_func,
routed_scaling_factor, layer.routed_scaling_factor,
e_score_correction_bias, layer.e_score_correction_bias,
apply_router_weight_on_input, layer.apply_router_weight_on_input,
activation, layer.activation,
) )
def forward_xpu( def forward_xpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
enable_eplb is not False layer.enable_eplb is not False
or expert_load_view is not None or layer.expert_load_view is not None
or logical_to_physical_map is not None or layer.logical_to_physical_map is not None
or logical_replica_count is not None or layer.logical_replica_count is not None
): ):
raise NotImplementedError("Expert load balancing is not supported for XPU.") raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion( return layer.ipex_fusion(
x, x,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
) )
def forward_tpu( def forward_tpu(
self, self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not use_grouped_topk assert not layer.use_grouped_topk
assert num_expert_group is None assert layer.num_expert_group is None
assert topk_group is None assert layer.topk_group is None
assert custom_routing_function is None assert layer.custom_routing_function is None
assert apply_router_weight_on_input is False assert layer.apply_router_weight_on_input is False
if scoring_func != "softmax": if layer.scoring_func != "softmax":
raise NotImplementedError( raise NotImplementedError(
"Only softmax scoring function is supported for TPU." "Only softmax scoring function is supported for TPU."
) )
if e_score_correction_bias is not None: if layer.e_score_correction_bias is not None:
raise NotImplementedError( raise NotImplementedError(
"Expert score correction bias is not supported for TPU." "Expert score correction bias is not supported for TPU."
) )
assert activation == "silu", f"{activation} is not supported for TPU." assert layer.activation == "silu", (
assert routed_scaling_factor == 1.0, ( f"{layer.activation} is not supported for TPU."
f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." )
assert layer.routed_scaling_factor == 1.0, (
f"routed_scaling_factor {layer.routed_scaling_factor} is "
"not supported for TPU."
) )
if ( if (
enable_eplb is not False layer.enable_eplb is not False
or expert_load_view is not None or layer.expert_load_view is not None
or logical_to_physical_map is not None or layer.logical_to_physical_map is not None
or logical_replica_count is not None or layer.logical_replica_count is not None
): ):
raise NotImplementedError("Expert load balancing is not supported for TPU.") raise NotImplementedError("Expert load balancing is not supported for TPU.")
return fused_moe_pallas( return fused_moe_pallas(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk=top_k, topk=layer.top_k,
gating_output=router_logits, gating_output=router_logits,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
renormalize=renormalize, renormalize=layer.renormalize,
) )
if current_platform.is_tpu(): if current_platform.is_tpu():

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
@ -669,25 +668,8 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -708,9 +690,9 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace, workspace=layer.workspace,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Union from typing import Any, Union
import torch import torch
@ -498,23 +497,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -534,10 +516,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum import enum
from collections.abc import Callable
from enum import Enum from enum import Enum
import torch import torch
@ -558,31 +557,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
if ( if (
self.allow_flashinfer self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet." "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
) )
@ -591,12 +573,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer=layer, layer=layer,
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k, top_k=layer.top_k,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
@ -619,9 +601,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_scale1=layer.w13_weight_scale_2, global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2, global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
@ -646,15 +628,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert expert_map is None, ( assert layer.expert_map is None, (
"Expert Parallelism / expert_map " "Expert Parallelism / expert_map "
"is currently not supported for " "is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoEMethod." "CompressedTensorsW4A4Nvfp4MoEMethod."
@ -670,7 +652,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments # TODO(bnell): derive these from arguments
m=x.shape[0], m=x.shape[0],
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
@ -1188,23 +1170,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -1215,7 +1180,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
if self.use_marlin: if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
@ -1228,9 +1195,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
@ -1248,9 +1215,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -1270,10 +1237,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None
if self.disable_expert_map
else layer.expert_map, # ???
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
else: else:
@ -1290,9 +1259,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
ab_strides1=self.ab_strides1_c_strides2, ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2, ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1, c_strides1=self.c_strides1,
@ -1314,10 +1283,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -1437,23 +1406,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -1469,10 +1421,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -1814,25 +1766,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -1853,9 +1790,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
g_idx1=layer.w13_weight_g_idx, g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx, g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,
@ -2057,23 +1994,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -2089,10 +2009,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -2372,32 +2292,15 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert activation in ("silu", "swigluoai", "swiglu"), ( assert layer.activation in ("silu", "swigluoai", "swiglu"), (
"Only SiLU/SwiGLUGU/SwiGLUUG are supported." "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
) )
assert expert_map is None, """expert_map/EP not implemented assert layer.expert_map is None, """expert_map/EP not implemented
for CPU dyn-4bit MoE.""" for CPU dyn-4bit MoE."""
def _act_kind(s: str) -> int: def _act_kind(s: str) -> int:
@ -2414,15 +2317,9 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, top_k=layer.top_k,
top_k=top_k, use_grouped_topk=layer.use_grouped_topk,
renormalize=renormalize, renormalize=layer.renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
) )
return torch.ops._C.dynamic_4bit_int_moe( return torch.ops._C.dynamic_4bit_int_moe(
@ -2435,8 +2332,8 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_in_features, layer.w2_in_features,
layer.w13_out_features, layer.w13_out_features,
layer.group_size, layer.group_size,
apply_router_weight_on_input, layer.apply_router_weight_on_input,
int(_act_kind(activation)), int(_act_kind(layer.activation)),
) )
@ -2707,28 +2604,11 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
): ):
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
) )
@ -2749,9 +2629,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else expert_map, expert_map=None if self.disable_expert_map else layer.expert_map,
a_strides1=self.a_strides1_c_strides2, a_strides1=self.a_strides1_c_strides2,
a_strides2=self.a_strides2, a_strides2=self.a_strides2,
b_strides1=self.b_strides1, b_strides1=self.b_strides1,

View File

@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# dont restrict as emulations # don't restrict as emulations
return 80 return 80
def create_weights( def create_weights(

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -140,23 +139,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -172,10 +154,10 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@ -99,7 +98,7 @@ from vllm.model_executor.parameter import (
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
@ -559,46 +558,50 @@ class Fp8LinearMethod(LinearMethodBase):
assert not self.act_q_static assert not self.act_q_static
size_k_first = False size_k_first = False
weight, weight_scale = process_fp8_weight_block_strategy( weight, weight_scale_inv = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale_inv layer.weight, layer.weight_scale_inv
) )
# Delete the weight_scale_inv parameter to avoid confusion
# with the weight_scale parameter # Update layer with new values
del layer.weight_scale_inv replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
elif not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else: else:
weight = layer.weight if not self.quant_config.is_checkpoint_fp8_serialized:
weight_scale = layer.weight_scale qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# If using w8a8, torch._scaled_mm needs per tensor, so # If checkpoint is fp8 per-tensor, handle that there are N scales for N
# requantize the logical shards as a single weight. # shards in a fused module
if not self.use_marlin: else:
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( weight = layer.weight
weight, weight_scale = layer.weight_scale
weight_scale,
layer.logical_widths,
getattr(layer, "input_scale", None),
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
# Update layer with new values. # If using w8a8, torch._scaled_mm needs per tensor, so
layer.weight = Parameter(weight.data, requires_grad=False) # requantize the logical shards as a single weight.
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) if not self.use_marlin:
layer.input_scale = ( weight, weight_scale, input_scale = (
Parameter(input_scale, requires_grad=False) process_fp8_weight_tensor_strategy(
if input_scale is not None weight,
else None weight_scale,
) layer.logical_widths,
getattr(layer, "input_scale", None),
)
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
# Update layer with new values.
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
if input_scale is not None:
replace_parameter(layer, "input_scale", input_scale)
else:
layer.input_scale = None
if self.use_marlin: if self.use_marlin:
prepare_fp8_layer_for_marlin( prepare_fp8_layer_for_marlin(
@ -625,7 +628,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.w8a8_block_fp8_linear.apply( return self.w8a8_block_fp8_linear.apply(
input=x, input=x,
weight=layer.weight, weight=layer.weight,
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale, input_scale=layer.input_scale,
bias=bias, bias=bias,
) )
@ -654,10 +657,15 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.nn.functional.linear(x, weight_bf16.t(), bias) return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin: if self.use_marlin:
if self.block_quant:
weight_scale = layer.weight_scale_inv
else:
weight_scale = layer.weight_scale
return apply_fp8_marlin_linear( return apply_fp8_marlin_linear(
input=x, input=x,
weight=layer.weight, weight=layer.weight,
weight_scale=layer.weight_scale, weight_scale=weight_scale,
workspace=layer.workspace, workspace=layer.workspace,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
@ -671,7 +679,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.w8a8_block_fp8_linear.apply( return self.w8a8_block_fp8_linear.apply(
input=x, input=x,
weight=layer.weight, weight=layer.weight,
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale, input_scale=layer.input_scale,
bias=bias, bias=bias,
) )
@ -941,22 +949,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight_scale_inv = layer.w2_weight_scale_inv w2_weight_scale_inv = layer.w2_weight_scale_inv
# torch.compile() cannot use Parameter subclasses. # torch.compile() cannot use Parameter subclasses.
layer.w13_weight = Parameter(w13_weight, requires_grad=False) replace_parameter(layer, "w13_weight", w13_weight)
layer.w13_weight_scale_inv = Parameter( replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
w13_weight_scale_inv, requires_grad=False replace_parameter(layer, "w2_weight", w2_weight)
) replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(
w2_weight_scale_inv, requires_grad=False
)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data layer.w13_weight.data, layer.w2_weight.data
) )
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) replace_parameter(layer, "w13_weight", shuffled_w13)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) replace_parameter(layer, "w2_weight", shuffled_w2)
# DeepGemm scales need to be transposed and aligned. We try to do # DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons. # it ahead of time for performance reasons.
@ -994,13 +998,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Re-initialize w13_scale because we directly quantize # Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor. # merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter( replace_parameter(
layer,
"w13_weight_scale",
torch.ones( torch.ones(
layer.local_num_experts, layer.local_num_experts,
dtype=torch.float32, dtype=torch.float32,
device=w13_weight.device, device=w13_weight.device,
), ),
requires_grad=False,
) )
for expert in range(layer.local_num_experts): for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
@ -1009,16 +1014,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
) )
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) replace_parameter(layer, "w13_weight", w13_weight)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) replace_parameter(layer, "w2_weight", w2_weight)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) replace_parameter(layer, "w13_weight", shuffled_w13)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) replace_parameter(layer, "w2_weight", shuffled_w2)
# If checkpoint is fp8, we need to handle that the # If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight # MoE kernels require single activation scale and single weight
# scale for w13 per expert. # scale for w13 per expert.
@ -1039,12 +1045,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"fp8 MoE layer. Using the maximum across experts " "fp8 MoE layer. Using the maximum across experts "
"for each layer." "for each layer."
) )
layer.w13_input_scale = torch.nn.Parameter( replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
layer.w13_input_scale.max(), requires_grad=False replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
# Normalize the weights and scales # Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = ( w13_weight, w13_weight_scale, w13_input_scale = (
@ -1058,22 +1060,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
) )
# Reset the parameter # Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) replace_parameter(layer, "w13_weight", w13_weight)
layer.w13_weight_scale = torch.nn.Parameter( replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
w13_weight_scale, requires_grad=False
)
if w13_input_scale is not None: if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter( replace_parameter(layer, "w13_input_scale", w13_input_scale)
w13_input_scale, requires_grad=False replace_parameter(layer, "w2_weight", w2_weight)
) replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
if w2_input_scale is not None: if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter( replace_parameter(layer, "w2_input_scale", w2_input_scale)
w2_input_scale, requires_grad=False
)
# Fp8 moe kernel needs single weight scale for w13 per expert. # Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert. # We take the max then dequant and requant each expert.
@ -1097,12 +1091,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) replace_parameter(layer, "w13_weight", shuffled_w13)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) replace_parameter(layer, "w2_weight", shuffled_w2)
layer.w13_weight_scale = torch.nn.Parameter( replace_parameter(layer, "w13_weight_scale", max_w13_scales)
max_w13_scales, requires_grad=False
)
if self.flashinfer_moe_backend is not None: if self.flashinfer_moe_backend is not None:
# NOTE: weights have to be swapped since the activation is # NOTE: weights have to be swapped since the activation is
@ -1245,41 +1237,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert activation == "silu", ( if layer.enable_eplb:
f"Expected 'silu' activation but got {activation}" raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
) )
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = ( e_score_correction_bias = (
e_score_correction_bias.to(x.dtype) layer.e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None if layer.e_score_correction_bias is not None
else None else None
) )
routing_method_type = layer.routing_method_type routing_method_type = layer.routing_method_type
@ -1293,29 +1264,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale_inv=layer.w13_weight_scale_inv, w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight, w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv, w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
top_k=top_k, top_k=layer.top_k,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition, intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
routing_method_type=routing_method_type, routing_method_type=routing_method_type,
routed_scaling=routed_scaling_factor, routed_scaling=layer.routed_scaling_factor,
) )
else: else:
assert not renormalize and custom_routing_function is not None assert (
not layer.renormalize and layer.custom_routing_function is not None
)
result = apply_flashinfer_per_tensor_scale_fp8( result = apply_flashinfer_per_tensor_scale_fp8(
layer=layer, layer=layer,
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
routing_bias=e_score_correction_bias, routing_bias=layer.e_score_correction_bias,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
top_k=top_k, top_k=layer.top_k,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
select_result = layer.select_experts( select_result = layer.select_experts(
@ -1336,13 +1309,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
elif self.use_marlin: elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
result = fused_marlin_moe( result = fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
@ -1355,20 +1330,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert activation == "silu", ( assert layer.activation == "silu", (
f"Expected 'silu' activation but got {activation}" f"Expected 'silu' activation but got {layer.activation}"
) )
if not self.block_quant: if not self.block_quant:
assert not renormalize and custom_routing_function is not None assert (
assert scoring_func == "sigmoid", ( not layer.renormalize and layer.custom_routing_function is not None
f"Expected 'sigmoid' scoring func but got {scoring_func}" )
assert layer.scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
) )
# Delegate to CUTLASS FlashInfer path; function already bound with # Delegate to CUTLASS FlashInfer path; function already bound with
# use_deepseek_fp8_block_scale for block-quant when applicable # use_deepseek_fp8_block_scale for block-quant when applicable
@ -1378,10 +1355,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False, inplace=False,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -1393,10 +1370,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=( allow_cutlass_block_scaled_grouped_gemm=(

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Mapping from collections.abc import Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Optional from typing import Any, Optional
@ -625,26 +625,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input: if layer.apply_router_weight_on_input:
raise NotImplementedError( raise NotImplementedError(
"Apply router weight on input is not supported for" "Apply router weight on input is not supported for"
"fused GGUF MoE method." "fused GGUF MoE method."
@ -662,7 +645,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_ids, topk_ids,
layer.w13_qweight_type.weight_type, layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, layer.w2_qweight_type.weight_type,
activation, layer.activation,
) )

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from typing import Any, Optional from typing import Any, Optional
@ -790,25 +789,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -829,9 +811,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx, g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,

View File

@ -5,6 +5,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -45,10 +46,13 @@ class QuantFP8(CustomOp):
super().__init__() super().__init__()
self.static = static self.static = static
self.group_shape = group_shape self.group_shape = group_shape
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
self.num_token_padding = num_token_padding self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales self.column_major_scales = column_major_scales
self.use_ue8m0 = use_ue8m0 self.use_ue8m0 = use_ue8m0
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled()
self.is_group_quant = group_shape.is_per_group() self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant: if self.is_group_quant:
assert not static, "Group quantization only supports dynamic mode" assert not static, "Group quantization only supports dynamic mode"
@ -92,6 +96,33 @@ class QuantFP8(CustomOp):
use_per_token_if_dynamic=self.use_per_token_if_dynamic, use_per_token_if_dynamic=self.use_per_token_if_dynamic,
) )
def forward_hip(
self,
x: torch.Tensor,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
use_aiter_quant = (
not self.is_group_quant
and self.use_aiter
and scale_ub is None
and x.is_contiguous()
)
use_aiter_per_tensor_quant = (
use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
)
use_aiter_per_token_quant = (
use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
)
if use_aiter_per_tensor_quant:
return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
if use_aiter_per_token_quant:
return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
# Fallback to CUDA implementation
return self.forward_cuda(x, scale, scale_ub)
def forward_native( def forward_native(
self, self,
x: torch.Tensor, x: torch.Tensor,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -440,31 +439,14 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return layer.ipex_fusion( return layer.ipex_fusion(
x, x,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
) )

View File

@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# skip if there are no weights to process (for example, weight reloading)
if not hasattr(layer, "q_scale"):
assert not hasattr(layer, "k_scale")
assert not hasattr(layer, "v_scale")
assert not hasattr(layer, "prob_scale")
return
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint. # regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to # No need to process kv scales after loading if we are going to

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from fnmatch import fnmatch from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@ -706,43 +705,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet." "EPLB not supported for `ModelOptFp8MoEMethod` yet."
) )
assert activation == "silu", ( assert layer.activation == "silu", (
f"Expected 'silu' activation but got {activation}" f"Expected 'silu' activation but got {layer.activation}"
) )
assert not renormalize
assert not layer.renormalize
return apply_flashinfer_per_tensor_scale_fp8( return apply_flashinfer_per_tensor_scale_fp8(
layer=layer, layer=layer,
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
routing_bias=e_score_correction_bias, routing_bias=layer.e_score_correction_bias,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
top_k=top_k, top_k=layer.top_k,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
# Expert selection # Expert selection
@ -752,9 +735,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert activation in ("silu", "relu2_no_mul"), ( assert layer.activation in ("silu", "relu2_no_mul"), (
"Expected activation to be in ('silu', 'relu2_no_mul')," "Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {activation}" f"but got {layer.activation}"
) )
return flashinfer_cutlass_moe_fp8( return flashinfer_cutlass_moe_fp8(
x, x,
@ -762,10 +745,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False, inplace=False,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
@ -779,11 +762,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
@ -1503,23 +1486,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.moe.is_act_and_mul: if not self.moe.is_act_and_mul:
assert ( assert (
@ -1534,7 +1500,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet." "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
) )
@ -1542,12 +1508,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer=layer, layer=layer,
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k, top_k=layer.top_k,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
num_expert_group=num_expert_group, num_expert_group=layer.num_expert_group,
topk_group=topk_group, topk_group=layer.topk_group,
custom_routing_function=custom_routing_function, custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
@ -1570,9 +1536,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale1=layer.w13_weight_scale_2, global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2, global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
) )
@ -1603,10 +1569,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
inplace=False, inplace=False,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
@ -1621,8 +1587,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO: derive from arguments # TODO: derive from arguments
m=x.shape[0], m=x.shape[0],
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@ -60,7 +59,7 @@ class MoeWNA16Config(QuantizationConfig):
if self.linear_quant_method == "gptq": if self.linear_quant_method == "gptq":
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config) self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
elif self.linear_quant_method == "awq": elif self.linear_quant_method in ("awq", "awq_marlin"):
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
device_capability = ( device_capability = (
-1 if capability_tuple is None else capability_tuple.to_int() -1 if capability_tuple is None else capability_tuple.to_int()
@ -107,7 +106,7 @@ class MoeWNA16Config(QuantizationConfig):
if linear_quant_method == "gptq": if linear_quant_method == "gptq":
has_zp = not cls.get_from_keys(config, ["sym"]) has_zp = not cls.get_from_keys(config, ["sym"])
modules_to_not_convert = [] modules_to_not_convert = []
elif linear_quant_method == "awq": elif linear_quant_method in ("awq", "awq_marlin"):
has_zp = cls.get_from_keys(config, ["zero_point"]) has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys_or( modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None config, ["modules_to_not_convert"], None
@ -184,7 +183,7 @@ class MoeWNA16Config(QuantizationConfig):
return GPTQConfig.from_config(self.full_config).get_quant_method( return GPTQConfig.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
elif self.linear_quant_method == "awq": elif self.linear_quant_method in ("awq", "awq_marlin"):
if self.use_marlin and check_marlin_supports_layer( if self.use_marlin and check_marlin_supports_layer(
layer, self.group_size layer, self.group_size
): ):
@ -362,27 +361,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
@ -395,9 +377,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -468,7 +450,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
shard_size = layer.intermediate_size_per_partition shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format # convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq": # awq_marlin uses the same weight format as awq
if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
assert layer.quant_config.weight_bits == 4 assert layer.quant_config.weight_bits == 4
if "weight" in weight_name: if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qweight") loaded_weight = convert_awq_tensor(loaded_weight, "qweight")

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb: if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_scale1=None, global_scale1=None,
global_scale2=None, global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
activation=activation, activation=layer.activation,
expert_map=expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
) )
assert _can_support_mxfp4( assert _can_support_mxfp4(
use_grouped_topk, layer.use_grouped_topk,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
expert_map, layer.expert_map,
custom_routing_function, layer.custom_routing_function,
e_score_correction_bias, layer.e_score_correction_bias,
apply_router_weight_on_input, layer.apply_router_weight_on_input,
scoring_func, layer.scoring_func,
activation, layer.activation,
expert_load_view, layer.expert_load_view,
logical_to_physical_map, layer.logical_to_physical_map,
logical_replica_count, layer.logical_replica_count,
), "MXFP4 are not supported with this configuration." ), "MXFP4 are not supported with this configuration."
if ( if (
@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output1_scale_scalar None, # output1_scale_scalar
None, # output1_scale_gate_scalar None, # output1_scale_gate_scalar
None, # output2_scale_scalar None, # output2_scale_scalar
global_num_experts, layer.global_num_experts,
top_k, layer.top_k,
None, # n_group None, # n_group
None, # topk_group None, # topk_group
self.intermediate_size, # padded to multiple of 256 self.intermediate_size, # padded to multiple of 256
@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.num_experts, # local num experts self.num_experts, # local num experts
None, None,
None, None,
1 if renormalize else 0, # routing_method_type, renormalize 1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=layer.top_k,
renormalize=renormalize, renormalize=layer.renormalize,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "swigluoai", ( assert layer.activation == "swigluoai", (
"Only swiglu_oai activation is supported for IPEX MXFP4 MoE" "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
) )
hidden_size_pad = round_up(self.original_hidden_size, 128) hidden_size_pad = round_up(self.original_hidden_size, 128)
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1))) x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
hidden_states = layer.ipex_fusion( hidden_states = layer.ipex_fusion(
x_pad, x_pad,
use_grouped_topk, layer.use_grouped_topk,
top_k, layer.top_k,
router_logits, router_logits,
renormalize, layer.renormalize,
topk_group, layer.topk_group,
num_expert_group, layer.num_expert_group,
activation="swiglu_oai", activation="swiglu_oai",
) )
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous() hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
@ -337,23 +336,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -371,13 +353,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=expert_map, expert_map=layer.expert_map,
) )
elif self.use_marlin: elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE." assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
@ -390,9 +374,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -404,10 +388,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
@ -597,23 +581,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -631,9 +598,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=layer.activation,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=expert_map, expert_map=layer.expert_map,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -645,10 +612,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
return out return out

View File

@ -3,7 +3,6 @@
# Copyright © 2025, Oracle and/or its affiliates. # Copyright © 2025, Oracle and/or its affiliates.
import os import os
from collections.abc import Callable
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
@ -359,23 +358,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x,
@ -394,9 +376,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_config.quant_type.id, quant_type_id=self.quant_config.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
workspace=workspace, workspace=workspace,
) )

View File

@ -27,10 +27,10 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt, fp8_gemm_nt,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported, is_deep_gemm_supported,
@ -195,6 +195,39 @@ direct_register_custom_op(
) )
def _triton_per_token_group_quant_fp8_impl(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` with ``per_token_group_quant_fp8`` using fixed options.

    Thin wrapper that pins ``column_major_scales=False`` and
    ``use_ue8m0=False`` so the operation only takes ``x`` and ``group_size``.

    Args:
        x: Tensor to quantize.
        group_size: Quantization group size forwarded to
            ``per_token_group_quant_fp8``.

    Returns:
        Whatever ``per_token_group_quant_fp8`` returns for these options
        (annotated as a pair of tensors).
    """
    # Both layout/format switches are hard-wired off for this path.
    fixed_options = {"column_major_scales": False, "use_ue8m0": False}
    return per_token_group_quant_fp8(x, group_size, **fixed_options)
def _triton_per_token_group_quant_fp8_fake(
    x: torch.Tensor,
    group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for the Triton per-token-group FP8 quantization.

    Allocates uninitialized output tensors with the expected shapes and
    dtypes; no quantization is performed.

    Args:
        x: Input tensor; unpacked as exactly two dimensions.
        group_size: Number of columns covered by one scale.

    Returns:
        A pair ``(x_fp8, scales)``: ``x_fp8`` has the same 2-D shape as ``x``
        in the platform FP8 dtype, and ``scales`` is float32 with one column
        per group (rounded up).
    """
    rows, cols = x.shape
    # Ceiling division: a trailing partial group still gets its own scale.
    groups_per_row = (cols + group_size - 1) // group_size

    quantized = torch.empty(
        (rows, cols), dtype=current_platform.fp8_dtype(), device=x.device
    )
    scales = torch.empty(
        (rows, groups_per_row), dtype=torch.float32, device=x.device
    )
    return quantized, scales
# Register the Triton quantization wrapper as a custom op named
# "triton_per_token_group_quant_fp8", passing the shape-only function as
# its fake implementation.
# NOTE(review): presumably reachable as
# torch.ops.vllm.triton_per_token_group_quant_fp8 -- confirm against
# direct_register_custom_op.
direct_register_custom_op(
"triton_per_token_group_quant_fp8",
_triton_per_token_group_quant_fp8_impl,
fake_impl=_triton_per_token_group_quant_fp8_fake,
)
# TODO fix ROCm->Triton custom path: # TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397 # https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp: class W8A8BlockFp8LinearOp:
@ -214,6 +247,7 @@ class W8A8BlockFp8LinearOp:
self.act_quant_group_shape = act_quant_group_shape self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90) self.is_hopper = current_platform.is_device_capability(90)
self.is_blackwell = current_platform.is_device_capability(100)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
# Get the correct blockscale mul and input quant operations. # Get the correct blockscale mul and input quant operations.
@ -269,7 +303,7 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0: if self.use_deep_gemm_e8m0 and self.is_blackwell:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm( q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d, input_2d,
group_size=self.act_quant_group_shape.col, group_size=self.act_quant_group_shape.col,
@ -340,17 +374,15 @@ class W8A8BlockFp8LinearOp:
if input_scale is not None: if input_scale is not None:
q_input = input_2d q_input = input_2d
# MI350 case uses triton kernel
elif use_triton: elif use_triton:
q_input, input_scale = per_token_group_quant_fp8( q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
input_2d, input_2d,
self.act_quant_group_shape.col, self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
) )
# MI300 uses tuned AITER ASM/C++ kernel
else: else:
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d) q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
input_2d, self.act_quant_group_shape.col
)
return gemm_a8w8_blockscale_op( return gemm_a8w8_blockscale_op(
q_input, q_input,
@ -1404,12 +1436,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
if should_use_deepgemm: if should_use_deepgemm:
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block( dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data, wq=layer.weight.data,
ws=layer.weight_scale.data, ws=layer.weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size), quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(), use_e8m0=is_deep_gemm_e8m0_used(),
) )
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False) replace_parameter(layer, "weight", dg_weight)
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False) replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool: def expert_weight_is_col_major(x: torch.Tensor) -> bool:

View File

@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_quant_input, marlin_quant_input,
should_use_atomic_add_reduce, should_use_atomic_add_reduce,
) )
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
size_n=part_size_n, size_n=part_size_n,
num_bits=8, num_bits=8,
) )
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) replace_parameter(layer, "weight", marlin_qweight)
# WEIGHT SCALES # WEIGHT SCALES
# Permute scales # Permute scales
@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
scales = layer.weight_scale.to(layer.orig_dtype) scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer): elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype) scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
group_size = -1 if weight_block_size is None else weight_block_size[1] group_size = -1 if weight_block_size is None else weight_block_size[1]
@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
) )
if input_dtype != torch.float8_e4m3fn: if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) if hasattr(layer, "weight_scale"):
replace_parameter(layer, "weight_scale", marlin_scales)
elif hasattr(layer, "weight_scale_inv"):
replace_parameter(layer, "weight_scale_inv", marlin_scales)
if hasattr(layer, "bias") and layer.bias is not None: if hasattr(layer, "bias") and layer.bias is not None:
assert layer.bias.shape == (part_size_n,) assert layer.bias.shape == (part_size_n,)
bias = marlin_permute_bias(layer.bias) bias = marlin_permute_bias(layer.bias)
layer.bias = torch.nn.Parameter(bias, requires_grad=False) replace_parameter(layer, "bias", bias)
def prepare_moe_fp8_layer_for_marlin( def prepare_moe_fp8_layer_for_marlin(

View File

@ -95,8 +95,11 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case (since) # from disk in this case. Skip requantization in this case (since)
# we already are quantized with the single scale. # we already are quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
#
# Extra note: upon weight reloading weight_scale.ndim == 0
unfused_module_in_checkpoint = ( unfused_module_in_checkpoint = (
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min weight_scale.ndim != 0
and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
) )
# If unfused checkpoint, need to requantize with the single scale. # If unfused checkpoint, need to requantize with the single scale.

View File

@ -367,6 +367,8 @@ class Qwen2MoeModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
@ -512,6 +514,12 @@ class Qwen2MoeModel(nn.Module):
continue continue
else: else:
name = remapped_kv_scale_name name = remapped_kv_scale_name
# GGUF: make sure that shared_expert_gate is a 2D tensor.
if (
"mlp.shared_expert_gate" in name
and len(loaded_weight.shape) == 1
):
loaded_weight = loaded_weight[None, :]
param = params_dict[name] param = params_dict[name]
weight_loader = getattr( weight_loader = getattr(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader

View File

@ -403,6 +403,7 @@ class Qwen3MoeModel(nn.Module):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.config = config self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
@ -505,6 +506,19 @@ class Qwen3MoeModel(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping() expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert loaded_weight.numel() == 1, (
f"KV scale numel {loaded_weight.numel()} != 1"
)
loaded_weight = loaded_weight.squeeze()
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:

View File

@ -50,6 +50,31 @@ def set_weight_attrs(
setattr(weight, key, value) setattr(weight, key, value)
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
    """
    Swap out a parameter on a layer without losing the ability to reload it.

    Intended for use inside `process_weights_after_loading` implementations.
    Must not be used on parameters that are tied/shared, since the attribute
    on `layer` is rebound to a brand new parameter object.

    Args:
        layer: Layer that owns the parameter being replaced
        param_name: Attribute name of the parameter on the layer
        new_data: Tensor (or parameter) holding the replacement data
    """
    # Unwrap an incoming Parameter so the replacement wraps the raw tensor.
    raw_data = new_data.data if isinstance(new_data, torch.nn.Parameter) else new_data
    replacement = torch.nn.Parameter(raw_data, requires_grad=False)

    # Carry the previous parameter's weight_loader over, when it has one,
    # so the weight can still be loaded again later.
    previous: torch.nn.Parameter | None = getattr(layer, param_name, None)
    if previous is not None and hasattr(previous, "weight_loader"):
        set_weight_attrs(replacement, {"weight_loader": previous.weight_loader})

    setattr(layer, param_name, replacement)
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
parent_map = getattr(model, "packed_modules_mapping", None) parent_map = getattr(model, "packed_modules_mapping", None)
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {} parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}

View File

@ -381,6 +381,8 @@ class RocmPlatform(Platform):
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE is_eager_execution = compilation_config == CUDAGraphMode.NONE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
if compilation_config.cudagraph_mode.has_full_cudagraphs(): if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs # decode context parallel does not support full cudagraphs
@ -400,8 +402,6 @@ class RocmPlatform(Platform):
) )
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 16 cache_config.block_size = 16
@ -415,6 +415,9 @@ class RocmPlatform(Platform):
): ):
compilation_config.custom_ops.append("+rms_norm") compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
@classmethod @classmethod
def verify_model_arch(cls, model_arch: str) -> None: def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS: if model_arch in _ROCM_UNSUPPORTED_MODELS:

Some files were not shown because too many files have changed in this diff Show More