Merge remote-tracking branch 'origin/main' into refactor-fp8-linear

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-12-11 04:51:04 +00:00
commit 52e2a31a95
131 changed files with 2893 additions and 1627 deletions

View File

@ -398,7 +398,8 @@ steps:
timeout_in_minutes: 25
gpu: h100
source_file_dependencies:
- vllm/
- vllm/v1/attention
- vllm/model_executor/layers
- tests/v1/determinism/
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
@ -440,23 +441,29 @@ steps:
working_dir: "/vllm-workspace/examples"
source_file_dependencies:
- vllm/entrypoints
- vllm/multimodal
- examples/
commands:
- pip install tensorizer # for tensorizer test
# for basic
- python3 offline_inference/basic/chat.py
- python3 offline_inference/basic/generate.py --model facebook/opt-125m
- python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10
- python3 offline_inference/basic/chat.py
- python3 offline_inference/prefix_caching.py
- python3 offline_inference/llm_engine_example.py
- python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py
- python3 offline_inference/basic/score.py
# for multi-modal models
- python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
- python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py
- python3 offline_inference/basic/score.py
# for pooling models
- python3 pooling/pooling/vision_language_pooling.py --seed 0
# for features demo
- python3 offline_inference/prefix_caching.py
- python3 offline_inference/llm_engine_example.py
- python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
# https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU
- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536
@ -718,6 +725,18 @@ steps:
- uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
- label: LM Eval Small Models # 53min
timeout_in_minutes: 75
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
# grade: Blocking
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
autorun_on_main: true
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- label: OpenAI API correctness # 10min
timeout_in_minutes: 15
mirror_hardwares: [amdexperimental, amdproduction]
@ -727,7 +746,7 @@ steps:
- csrc/
- vllm/entrypoints/openai/
- vllm/model_executor/models/whisper.py
commands: # LMEval
commands: # LMEval+Transcription WER check
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
- pytest -s entrypoints/openai/correctness/
@ -963,6 +982,19 @@ steps:
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min
timeout_in_minutes: 180
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- vllm/multimodal/
- vllm/inputs/
- vllm/v1/core/
commands:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
- label: Multi-Modal Models Test (Extended) 1 # 60min
timeout_in_minutes: 120
mirror_hardwares: [amdexperimental]
@ -1098,7 +1130,6 @@ steps:
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- vllm/model_executor/layers/fused_moe/layer.py
- tests/compile/test_fusion_attn.py
- tests/compile/test_silu_mul_quant_fusion.py
- tests/compile/distributed/test_fusion_all_reduce.py
@ -1132,12 +1163,25 @@ steps:
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- tests/compile/distributed/test_fusions_e2e.py
- tests/compile/fullgraph/test_full_graph.py
commands:
- nvidia-smi
# Run all e2e fusion tests
- pytest -v -s tests/compile/distributed/test_fusions_e2e.py
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
gpu: b200
optional: true # run on nightlies
source_file_dependencies:
- tests/evals/gpt_oss
- vllm/model_executor/models/gpt_oss.py
- vllm/model_executor/layers/quantization/mxfp4.py
- vllm/v1/attention/backends/flashinfer.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: Blackwell Quantized MoE Test
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
@ -1155,6 +1199,16 @@ steps:
commands:
- pytest -s -v tests/quantization/test_blackwell_moe.py
- label: Blackwell LM Eval Small Models
timeout_in_minutes: 120
gpu: b200
optional: true # run on nightlies
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
##### 1 GPU test #####
##### multi gpus test #####
@ -1397,6 +1451,39 @@ steps:
- TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s -x lora/test_mixtral.py
- label: LM Eval Large Models # optional
gpu: a100
optional: true
mirror_hardwares: [amdexperimental]
agent_pool: mi325_4
# grade: Blocking
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
##### H100 test #####
- label: LM Eval Large Models (H100) # optional
gpu: h100
optional: true
mirror_hardwares: [amdexperimental]
agent_pool: mi325_4
# grade: Blocking
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
##### H200 test #####
- label: Distributed Tests (H200) # optional
mirror_hardwares: [amdexperimental]
@ -1440,29 +1527,6 @@ steps:
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- label: Blackwell LM Eval Small Models
timeout_in_minutes: 120
gpu: b200
optional: true # run on nightlies
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
- label: Multi-Modal Accuracy Eval (Small Models) # 10min
timeout_in_minutes: 70
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- vllm/multimodal/
- vllm/inputs/
- vllm/v1/core/
commands:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
- label: LM Eval Large Models (4 Card)
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4
@ -1478,21 +1542,6 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
- label: LM Eval Large Models (H100) # optional
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4
# grade: Blocking
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
- label: ROCm LM Eval Large Models (8 Card)
mirror_hardwares: [amdproduction]
agent_pool: mi325_8
@ -1517,6 +1566,20 @@ steps:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
##### RL Integration Tests #####
- label: Prime-RL Integration Test # 15min
mirror_hardwares: [amdexperimental]
agent_pool: mi325_2
# grade: Blocking
timeout_in_minutes: 30
optional: true
num_gpus: 2
working_dir: "/vllm-workspace"
source_file_dependencies:
- vllm/
- .buildkite/scripts/run-prime-rl-test.sh
commands:
- bash .buildkite/scripts/run-prime-rl-test.sh
- label: DeepSeek V2-Lite Accuracy
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4
@ -1550,17 +1613,26 @@ steps:
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
##### RL Integration Tests #####
- label: Prime-RL Integration Test # 15min
- label: DeepSeek V2-Lite Async EPLB Accuracy
timeout_in_minutes: 60
mirror_hardwares: [amdexperimental]
agent_pool: mi325_2
agent_pool: mi325_4
# grade: Blocking
timeout_in_minutes: 30
gpu: h100
optional: true
num_gpus: 2
num_gpus: 4
working_dir: "/vllm-workspace"
source_file_dependencies:
- vllm/
- .buildkite/scripts/run-prime-rl-test.sh
commands:
- bash .buildkite/scripts/run-prime-rl-test.sh
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
timeout_in_minutes: 60
mirror_hardwares: [amdexperimental]
agent_pool: mi325_4
# grade: Blocking
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040

View File

@ -468,7 +468,9 @@ steps:
# tests covered elsewhere.
# Use `find` to launch multiple instances of pytest so that
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
# However, find does not normally propagate error codes, so we combine it with xargs
# (using -0 for proper path handling)
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
@ -482,7 +484,9 @@ steps:
# as it is a heavy test that is covered in other steps.
# Use `find` to launch multiple instances of pytest so that
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
- "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\\\;"
# However, find does not normally propagate error codes, so we combine it with xargs
# (using -0 for proper path handling)
- "find compile/fullgraph -maxdepth 1 -name 'test_*.py' -not -name 'test_full_graph.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
- label: PyTorch Fullgraph Test # 27min
timeout_in_minutes: 40

View File

@ -13,7 +13,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Set up Python
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0

View File

@ -12,7 +12,7 @@ jobs:
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v6.0.1
- uses: astral-sh/setup-uv@v7
with:

View File

@ -16,7 +16,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
with:
python-version: "3.12"

View File

@ -15,7 +15,7 @@ jobs:
actions: write
runs-on: ubuntu-latest
steps:
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
# Increasing this value ensures that changes to this workflow
# propagate to all issues and PRs in days rather than months

View File

@ -96,8 +96,9 @@ start_server() {
# This correctly passes each element as a separate argument.
if [[ -n "$profile_dir" ]]; then
# Start server with profiling enabled
VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}"
VLLM_SERVER_DEV_MODE=1 \
vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 &
else
# Start server without profiling
VLLM_SERVER_DEV_MODE=1 \

View File

@ -963,8 +963,7 @@ def create_argument_parser():
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
parser.add_argument(
"--result-dir",

View File

@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
endif()
# Build ACL with CMake
set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
set(CMAKE_BUILD_TYPE "Release")
set(ARM_COMPUTE_ARCH "armv8.2-a")
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
set(ARM_COMPUTE_BUILD_TESTING "OFF")
set(_cmake_config_cmd
${CMAKE_COMMAND} -G Ninja -B build
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF

View File

@ -186,7 +186,7 @@ struct AttentionMetadata {
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
// * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains:
// - flags: bool array to indicate wether the split is finished
// - flags: bool array to indicate whether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4
class AttentionScratchPad {

View File

@ -617,7 +617,7 @@ struct MacheteCollectiveMma {
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// with `SwapAB ? N : M -> M` since we dont support SwapAB
// with `SwapAB ? N : M -> M` since we don't support SwapAB
// clang-format off
template<class ProblemShape>
static bool

View File

@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) {
int nPrRnd = div1 * div2;
int rnds0 = N / nPrRnd;
nPrRnd -= div1 * 3;
int rnds3 = N / nPrRnd;
nPrRnd -= div1;
int rnds4 = N / nPrRnd;
nPrRnd -= div1;
int rnds5 = N / nPrRnd;
nPrRnd -= div1;
int rnds6 = N / nPrRnd;
nPrRnd -= div1;
int rnds7 = N / nPrRnd;
nPrRnd -= div1;
int rnds8 = N / nPrRnd;
nPrRnd -= div1;
int rnds9 = N / nPrRnd;
nPrRnd -= div1;
int rtn = div2;
if (rnds0 == rnds3) rtn = div2 - 3;
if (rnds0 == rnds4) rtn = div2 - 4;
if (rnds0 == rnds5) rtn = div2 - 5;
if (rnds0 == rnds6) rtn = div2 - 6;
if (rnds0 == rnds7) rtn = div2 - 7;
if (rnds0 == rnds8) rtn = div2 - 8;
if (rnds0 == rnds9) rtn = div2 - 9;
return rtn;
int rnds[13];
for (int i = 0; i < 13; i++) {
rnds[i] = (N + nPrRnd - 1) / nPrRnd;
nPrRnd -= div1;
}
for (int i = 12; i >= 0; i--)
if (rnds[0] == rnds[i]) return (div2 - i);
}
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else if (K_in * N_in <= max_lds_len * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
} \
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
}
#define WVSPLIT_TILE(_sYT, __N) \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
// first shoot for biggest tile-size that keeps all simd busy,
// then cut the active waves to balance their distribution...
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
switch (N_in) {
case 1:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
WVSPLIT_TILE(sYT, 1)
break;
case 2:
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
WVSPLIT_TILE(sYT, 2)
break;
case 3:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
WVSPLIT_TILE(sYT, 3)
break;
case 4:
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
WVSPLIT_TILE(sYT, 4)
break;
default:
throw std::runtime_error(

View File

@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes.
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.StructuredOutputsConfig][]
- [vllm.config.ProfilerConfig][]
- [vllm.config.ObservabilityConfig][]
- [vllm.config.KVTransferConfig][]
- [vllm.config.CompilationConfig][]

View File

@ -5,16 +5,15 @@
## Profile with PyTorch Profiler
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by setting `--profiler-config`
when launching the server, and setting the entries `profiler` to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. Additionally, you can control the profiling content by specifying the following additional arguments in the config:
- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default
- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
- `torch_profiler_record_shapes` to enable recording Tensor Shapes, off by default
- `torch_profiler_with_memory` to record memory, off by default
- `torch_profiler_with_stack` to enable recording stack information, on by default
- `torch_profiler_with_flops` to enable recording FLOPs, off by default
- `torch_profiler_use_gzip` to control gzip-compressing profiling files, on by default
- `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default
When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
#### OpenAI Server
```bash
VLLM_TORCH_PROFILER_DIR=./vllm_profile \
vllm serve meta-llama/Llama-3.1-8B-Instruct
vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
```
vllm bench command:
@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with `
```bash
# server
VLLM_TORCH_CUDA_PROFILE=1 \
nsys profile \
--trace-fork-before-exec=true \
--cuda-graph-trace=node \
--capture-range=cudaProfilerApi \
--capture-range-end repeat \
vllm serve meta-llama/Llama-3.1-8B-Instruct
vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda
# client
vllm bench serve \

View File

@ -22,7 +22,7 @@ python tools/install_nixl_from_source_ubuntu.py
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
```bash
# Example UCX configuration, adjust according to your enviroment
# Example UCX configuration, adjust according to your environment
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
```

View File

@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
return self.end_token_id in delta_token_ids
...
```

View File

@ -1,14 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time
from vllm import LLM, SamplingParams
# enable torch profiler, can also be set on cmd line
os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
# Sample prompts.
prompts = [
"Hello, my name is",
@ -22,7 +18,14 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=1,
profiler_config={
"profiler": "torch",
"torch_profiler_dir": "./vllm_profile",
},
)
llm.start_profile()

View File

@ -75,7 +75,7 @@ torchgeo==0.7.0
mteb==2.1.2
# Data processing
xgrammar==0.1.27
xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84
# Test async scheduling
# Utilities

View File

@ -17,7 +17,6 @@ def test_compile():
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
@pytest.mark.xfail
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
"""Test that Qwen2.5-VL vision submodules are compiled.

View File

@ -80,6 +80,8 @@ def test_compile_ranges(use_fresh_inductor_cache):
vllm_config = VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
@ -112,6 +114,8 @@ def test_compile_config_get_compile_ranges():
VllmConfig(
scheduler_config=SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
),
compilation_config=compilation_config,
)
@ -134,6 +138,8 @@ def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
)
scheduler_config = SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_encoder_decoder=False,
)
torch.set_default_device("cuda")

View File

@ -1,10 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
import vllm.config
import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
@ -237,13 +241,85 @@ def _generate_kernel_groupshape_combinations():
KERNEL_GROUPSHAPE_COMBINATIONS = _generate_kernel_groupshape_combinations()
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, **kwargs):
super().__init__()
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = [
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
for _ in range(3)
]
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
self.eps = eps
self.w8a8_block_fp8_linear = [
TestBlockFP8Layer(
GroupShape(128, 128),
self.w[i],
self.wscale[i],
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
for i in range(3)
]
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
x2 = self.w8a8_block_fp8_linear[0](y)
# make sure resid is used for replacement to work
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
x2, resid, self.norm_weight[1], self.eps
)
x3 = self.w8a8_block_fp8_linear[1](y2)
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
x3, resid, self.norm_weight[2], self.eps
)
x4 = self.w8a8_block_fp8_linear[2](y3)
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
x4, resid, self.norm_weight[3], self.eps
)
return y4
def ops_in_model_before(self):
return [
torch.ops.vllm.rocm_aiter_rms_norm,
torch.ops.vllm.rocm_aiter_group_fp8_quant,
]
def ops_in_model_before_partial(self):
return []
def ops_in_model_after(self):
return [
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
list(itertools.product([TestModel], [True, False], [True, False]))
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
@ -253,9 +329,13 @@ def test_fusion_rmsnorm_quant(
num_tokens,
eps,
kernel_groupshape,
model_class,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
@ -290,7 +370,14 @@ def test_fusion_rmsnorm_quant(
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
if model_class is TestRmsnormGroupFp8QuantModel:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
)
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
else:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
@ -325,7 +412,10 @@ def test_fusion_rmsnorm_quant(
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
if (
not enable_rms_norm_custom_op
and model_class is not TestRmsnormGroupFp8QuantModel
):
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7

View File

@ -5,9 +5,14 @@ import copy
import pytest
import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.inductor_pass import (
CallableInductorPass,
InductorPass,
pass_context,
)
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import ModelConfig, VllmConfig
from vllm.config.utils import Range
# dummy custom pass that doesn't inherit
@ -42,35 +47,37 @@ class ProperPass(InductorPass):
],
)
def test_pass_manager_uuid(callable):
# Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
# Set the pass context as PassManager uuid uses it
with pass_context(Range(start=1, end=8)):
# Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
pass_manager = PostGradPassManager()
pass_manager.configure(config)
pass_manager = PostGradPassManager()
pass_manager.configure(config)
# Check that UUID is different if the same pass is added 2x
pass_manager.add(callable)
uuid1 = pass_manager.uuid()
pass_manager.add(callable)
uuid2 = pass_manager.uuid()
assert uuid1 != uuid2
# Check that UUID is different if the same pass is added 2x
pass_manager.add(callable)
uuid1 = pass_manager.uuid()
pass_manager.add(callable)
uuid2 = pass_manager.uuid()
assert uuid1 != uuid2
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2 = PostGradPassManager()
pass_manager2.configure(config)
pass_manager2.add(callable)
assert uuid1 == pass_manager2.uuid()
# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2 = PostGradPassManager()
pass_manager2.configure(config)
pass_manager2.add(callable)
assert uuid1 == pass_manager2.uuid()
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.compilation_config.pass_config.fuse_norm_quant = (
not config2.compilation_config.pass_config.fuse_norm_quant
)
config2.compilation_config.pass_config.fuse_act_quant = (
not config2.compilation_config.pass_config.fuse_act_quant
)
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid()
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.compilation_config.pass_config.fuse_norm_quant = (
not config2.compilation_config.pass_config.fuse_norm_quant
)
config2.compilation_config.pass_config.fuse_act_quant = (
not config2.compilation_config.pass_config.fuse_act_quant
)
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid()

View File

@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._aiter_ops import IS_AITER_FOUND
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.activation_quant_fusion import (
FUSED_OPS,
@ -24,6 +25,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
kNvfp4Quant,
@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return [FUSED_OPS[kNvfp4Quant]]
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
scale_hidden_size = (hidden_size + 128 - 1) // 128
self.wscale = torch.rand(
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
]
def ops_in_model_after(self):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)],
+ [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
model_class: type[
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@ -172,9 +216,15 @@ def test_fusion_silu_and_mul_quant(
)
with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config)
fusion_passes = [ActivationQuantFusionPass(config)]
if IS_AITER_FOUND:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
@ -193,12 +243,14 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)
assert fusion_pass.matched_count == 1
assert sum([p.matched_count for p in fusion_passes]) == 1
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
)
from vllm.entrypoints.responses_utils import (
construct_chat_message_with_tool_call,
_construct_single_message_from_response_item,
construct_chat_messages_with_tool_call,
convert_tool_responses_to_completions_format,
)
@ -42,7 +44,43 @@ class TestResponsesUtils:
assert result == {"type": "function", "function": input_tool}
def test_construct_chat_message_with_tool_call(self):
def test_construct_chat_messages_with_tool_call(self):
"""Test construction of chat messages with tool calls."""
reasoning_item = ResponseReasoningItem(
id="lol",
summary=[],
type="reasoning",
content=[
Content(
text="Leroy Jenkins",
type="reasoning_text",
)
],
encrypted_content=None,
status=None,
)
mcp_tool_item = ResponseFunctionToolCall(
id="mcp_123",
call_id="call_123",
type="function_call",
status="completed",
name="python",
arguments='{"code": "123+456"}',
)
input_items = [reasoning_item, mcp_tool_item]
messages = construct_chat_messages_with_tool_call(input_items)
assert len(messages) == 1
message = messages[0]
assert message["role"] == "assistant"
assert message["reasoning"] == "Leroy Jenkins"
assert message["tool_calls"][0]["id"] == "call_123"
assert message["tool_calls"][0]["function"]["name"] == "python"
assert (
message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
)
def test_construct_single_message_from_response_item(self):
item = ResponseReasoningItem(
id="lol",
summary=[],
@ -56,7 +94,7 @@ class TestResponsesUtils:
encrypted_content=None,
status=None,
)
formatted_item = construct_chat_message_with_tool_call(item)
formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant"
assert formatted_item["reasoning"] == "Leroy Jenkins"
@ -74,7 +112,7 @@ class TestResponsesUtils:
status=None,
)
formatted_item = construct_chat_message_with_tool_call(item)
formatted_item = _construct_single_message_from_response_item(item)
assert formatted_item["role"] == "assistant"
assert (
formatted_item["reasoning"]
@ -88,7 +126,7 @@ class TestResponsesUtils:
output="1234",
status="completed",
)
formatted_item = construct_chat_message_with_tool_call(tool_call_output)
formatted_item = _construct_single_message_from_response_item(tool_call_output)
assert formatted_item["role"] == "tool"
assert formatted_item["content"] == "1234"
assert formatted_item["tool_call_id"] == "temp"
@ -102,7 +140,7 @@ class TestResponsesUtils:
status=None,
)
with pytest.raises(ValueError):
construct_chat_message_with_tool_call(item)
_construct_single_message_from_response_item(item)
output_item = ResponseOutputMessage(
id="msg_bf585bbbe3d500e0",
@ -119,6 +157,6 @@ class TestResponsesUtils:
type="message",
)
formatted_item = construct_chat_message_with_tool_call(output_item)
formatted_item = _construct_single_message_from_response_item(output_item)
assert formatted_item["role"] == "assistant"
assert formatted_item["content"] == "dongyi"

View File

@ -7,7 +7,8 @@ import math
import pytest
import torch
from vllm.platforms import current_platform
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len)
]
def get_attn_isa(
block_size: int | None = None,
dtype: torch.dtype | None = None,
):
if block_size and dtype:
return _get_attn_isa(dtype, block_size)
else:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
return "neon"
elif torch._C._cpu._is_amx_tile_supported():
return "amx"
else:
return "vec"
# rand number generation takes too much time, cache rand tensors
@functools.lru_cache(maxsize=128, typed=False)
def tensor_cache(
@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", QTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["neon"])
@pytest.mark.skipif(
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
reason="Not an Arm CPU.",
)
def test_varlen_with_paged_kv_normal_neon(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
@pytest.mark.parametrize("isa", [get_attn_isa()])
def test_varlen_with_paged_kv_softcap(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [True])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
@pytest.mark.parametrize("isa", [get_attn_isa()])
def test_varlen_with_paged_kv_alibi(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [True])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
@pytest.mark.parametrize("isa", [get_attn_isa()])
def test_varlen_with_paged_kv_sink(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],

View File

@ -26,7 +26,14 @@ def clear_cache():
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
devices = ["cpu"]
if current_platform.is_cuda():
devices.append("cuda")
if current_platform.is_rocm():
devices.append("hip")
@pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(device: str):
"""
Test the attention selector between different platform and device.
@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention

View File

@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
.clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE)
)
return ref_out, ref_scale.view((1, 1))
return ref_out, ref_scale.view(1)
def native_w8a8_block_matmul(

View File

@ -54,6 +54,10 @@ def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_info = torch.finfo(current_platform.fp8_dtype())
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
)
@torch.inference_mode()
def test_w8a8_block_fp8_cutlass_matmul():
# Test simple case where weight.shape % 128 != 0,
@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul():
assert rel_diff < 0.001
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),

View File

@ -15,6 +15,9 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
MNK_FACTORS = [
(1, 256, 128),
(1, 16384, 1024),

View File

@ -21,6 +21,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
if not current_platform.is_cuda():
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import mimetypes
import os
@ -186,6 +187,7 @@ async def test_fetch_image_error_conversion():
connector.fetch_image(broken_img)
@pytest.mark.flaky(reruns=3, reruns_delay=5)
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
@ -198,8 +200,12 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
}
)
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
try:
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
except (TimeoutError, asyncio.TimeoutError) as e:
pytest.skip(f"Timeout fetching video (CI network flakiness): {e}")
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async

View File

@ -10,10 +10,14 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod,
Fp8MoEMethod,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
MODELS = [
@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
),
)
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# any postprocessing that is applied to the weights such as padding and repacking
# (excluding device sharding) must also be applied to the reloaded weights
#
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading(
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
):
if is_checkpoint_fp8_serialized is False:
pytest.skip("FP8 weight reloading does not support online quantization")
if method_cls is Fp8MoEMethod and weight_block_size is None:
pytest.skip(
"FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
"If this is your use case, consider using a restore function like #26327"
)
with torch.device("cuda:0"):
config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
weight_block_size=weight_block_size,
)
if method_cls is Fp8LinearMethod:
layer = torch.nn.Linear(1, 1)
method = method_cls(config)
method.create_weights(
layer=layer,
input_size_per_partition=1,
output_partition_sizes=[1],
input_size=1,
output_size=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
else:
layer = FusedMoE(
num_experts=1,
top_k=1,
hidden_size=1,
intermediate_size=1,
)
method = method_cls(config, layer)
method.create_weights(
layer=layer,
num_experts=1,
hidden_size=1,
intermediate_size_per_partition=1,
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
method.use_marlin = use_marlin
# capture weights format during loading
original_metadata = [
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
for name, param in layer.named_parameters()
]
# test loading
for name, shape, _ in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)
# test reloading works after loading
# assuming that no reshaping occurred
for name, shape, original_weight_loader in original_metadata:
param = getattr(layer, name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert weight_loader is original_weight_loader
weight_loader(param, torch.zeros(shape)) # cannot use empty
method.process_weights_after_loading(layer)

View File

@ -132,6 +132,41 @@ class TestBaseThinkingReasoningParserMethods:
is False
)
def test_is_reasoning_end_streaming(self, test_tokenizer):
"""Test the is_reasoning_end_streaming method."""
parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id
start_token_id = parser.start_token_id
assert (
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
is True
)
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
assert parser.is_reasoning_end_streaming([], []) is False
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id], [end_token_id]
)
is True
)
assert (
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
)
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id, 2, start_token_id, 2],
[2],
)
is False
)
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id, 2, 2], [2]
)
is False
)
def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer)

View File

@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
input_tokens = tokenizer.tokenize(input_text)
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
assert parser.is_reasoning_end(input_ids) is True
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
# Test extract_content_ids returns all input_ids
assert parser.extract_content_ids(input_ids) == input_ids

View File

@ -615,6 +615,7 @@ def test_extract_tool_calls_streaming(
"single_tool_weather",
"multiple_tool_calls",
"content_before_tool",
"complex",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
@ -673,6 +674,21 @@ def test_extract_tool_calls_streaming(
],
"bla",
),
(
# Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"",
),
],
)
def test_extract_tool_calls_streaming_one_chunk(

View File

@ -53,9 +53,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer
@ -1381,8 +1378,8 @@ class TestBlockFP8Layer:
"""
Test wrapper for W8A8BlockFp8LinearOp to match TestFP8Layer interface.
This is a workaround until W8A8BlockFp8LinearOp has a similar API to
FP8ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization).
This is a workaround until W8A8BlockFp8LinearOp implements
ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization).
"""
def __init__(
@ -1390,20 +1387,21 @@ class TestBlockFP8Layer:
group_shape: GroupShape,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
cutlass_block_fp8_supported: bool = False,
use_aiter_and_is_supported: bool = False,
):
self.kernel = None # For compatibility with TestFP8Layer interface
self.linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)
self.weight = weight
self.weight_scale = weight_scale
self.input_scale = input_scale
def forward(
def __call__(
self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return self.linear_op.apply(

View File

@ -106,8 +106,8 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,

View File

@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fall back mode
# 4. disable_full should have a fall back mode (e.g., cascade attention)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE

View File

@ -10,6 +10,7 @@ from utils import (
BACKENDS,
_extract_step_logprobs,
_random_prompt,
is_device_capability_below_90,
resolve_model_name,
skip_unsupported,
)
@ -17,6 +18,8 @@ from utils import (
import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
@skip_unsupported
@pytest.mark.timeout(1000)
@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
max_model_len=8192,
dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# Use more realistic prompts for better token generation
@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="bfloat16",
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
prompt = "the capital of france is"
@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# Use a few test prompts
@ -925,6 +933,8 @@ def LLM_with_max_seqs(
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
# Enable for MOE models
# enable_expert_parallel=True,
)

View File

@ -11,8 +11,10 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
# Supports testing on Ampere and Ada Lovelace devices.
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
reason="Requires CUDA and >= Ampere (SM80)",
)
BACKENDS: list[str] = [
@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output):
return t, inner.token_ids
return None, None
def is_device_capability_below_90() -> bool:
return not current_platform.has_device_capability(90)

View File

@ -8,6 +8,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.platforms import current_platform
from vllm.sampling_params import StructuredOutputsParams
from vllm.v1.metrics.reader import Metric
@ -70,6 +71,18 @@ def test_without_spec_decoding(
(True, "uni", True, None, True),
]
if current_platform.is_rocm():
# On ROCm, Only test with structured_outputs (deterministic)
# and skip chunk_prefill (more variable).
test_configs = [
cfg
for cfg in test_configs
if not cfg[4] # skip chunk_prefill=True
]
test_sampling_params = [
p for p in test_sampling_params if p.get("structured_outputs") is not None
]
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
@ -108,7 +121,14 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True),
]
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
run_tests(
monkeypatch,
MTP_MODEL,
test_configs,
test_sampling_params,
is_testing_with_spec_decoding=True,
)
@dynamo_config.patch(cache_size_limit=16)
@ -117,13 +137,23 @@ def run_tests(
model: str,
test_configs: list[tuple],
test_sampling_params: list[dict[str, Any]],
is_testing_with_spec_decoding: bool = False,
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m:
# avoid precision errors
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
if current_platform.is_rocm():
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
# lock matmul precision to full FP32
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
# m.setenv("VLLM_BATCH_INVARIANT", "1")
outputs: list[tuple[str, list, list]] = []
for n, (
@ -143,6 +173,7 @@ def run_tests(
async_scheduling,
spec_config,
test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
)
outputs.append(test_results)
@ -172,17 +203,34 @@ def run_tests(
name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}",
)
assert _all_logprobs_match(base_logprobs, test_logprobs)
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
# logprobs comparison when logprobs are requested
skip_logprobs_check = (
current_platform.is_rocm()
and params.get("logprobs")
and is_testing_with_spec_decoding
)
if not skip_logprobs_check:
assert _all_logprobs_match(base_logprobs, test_logprobs)
if (
base_acceptance_rate is not None
and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
# Preemption causes more variance in acceptance rates
if (
current_platform.is_rocm()
and "preemption=True" in test_config
):
tolerance = 0.10
else:
tolerance = 0.05
assert (
test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate
== pytest.approx(base_acceptance_rate, rel=5e-2)
== pytest.approx(base_acceptance_rate, rel=tolerance)
)
else:
# Currently the reported acceptance rate is expected to be
@ -213,6 +261,7 @@ def run_test(
async_scheduling: bool,
spec_config: dict[str, Any] | None,
test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False,
):
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
@ -231,6 +280,15 @@ def run_test(
print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}")
print("-" * 80)
# On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
# spec decoding test (TRITON_ATTN) for better precision.
# On others: always use float32.
if current_platform.is_rocm() and not is_testing_with_spec_decoding:
dtype = "float16"
else:
dtype = "float32"
with VllmRunner(
model,
max_model_len=512,
@ -240,7 +298,7 @@ def run_test(
# enforce_eager=True,
async_scheduling=async_scheduling,
distributed_executor_backend=executor,
dtype="float32", # avoid precision errors
dtype=dtype,
speculative_config=spec_config,
disable_log_stats=False,
**cache_arg,
@ -300,11 +358,21 @@ def _all_logprobs_match(req_a, req_b) -> bool:
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
return len(lps_a) == len(lps_b) and all(
a.decoded_token == b.decoded_token
and a.rank == b.rank
and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
if current_platform.is_rocm():
# ROCm has higher numerical variance
# due to use of float16.
rel_tol, abs_tol = 5e-2, 1e-5
else:
rel_tol, abs_tol = 1e-3, 1e-6
return (
len(lps_a) == len(lps_b)
and lps_a.keys() == lps_b.keys()
and all(
a.decoded_token == b.decoded_token
and a.rank == b.rank
and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
)
)

View File

@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""
import multiprocessing
import sys
import traceback
import pytest
import torch
@pytest.fixture
def sync_tracker():
"""
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
lazy init syncs. Prints stack traces immediately when syncs occur.
"""
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
# Shared counter for cross-process communication (inherited by fork)
sync_count = multiprocessing.Value("i", 0)
# Save original property
original_prop = CommonAttentionMetadata.seq_lens_cpu
original_fget = original_prop.fget
# Create tracking wrapper
def tracking_seq_lens_cpu(self):
if self._seq_lens_cpu is None:
# Increment counter
with sync_count.get_lock():
sync_count.value += 1
count = sync_count.value
# Print stack trace immediately (shows in subprocess output)
print(f"\n{'=' * 60}", file=sys.stderr)
print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
print(f"{'=' * 60}", file=sys.stderr)
traceback.print_stack(file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
sys.stderr.flush()
return original_fget(self)
# Apply patch
CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)
class SyncTracker:
@property
def count(self) -> int:
return sync_count.value
def assert_no_sync(self, msg: str = ""):
count = sync_count.value
assert count == 0, (
f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
f"{count} times. See stack traces above. {msg}"
)
yield SyncTracker()
# Restore original property
CommonAttentionMetadata.seq_lens_cpu = original_prop
torch._dynamo.reset()
# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env)
SPEC_DECODE_CONFIGS = [
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama3_2_1B_speculator.eagle3",
"eagle3",
2,
id="eagle3-llama",
),
pytest.param(
"eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random",
"eagle",
2,
id="eagle-mla-deepseek",
),
]
@pytest.mark.parametrize(
"model,spec_model,method,num_spec_tokens",
SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
sync_tracker,
model: str,
spec_model: str,
method: str,
num_spec_tokens: int,
):
"""
Test that no implicit GPU-CPU sync occurs during speculative decoding
generation.
"""
# Import vLLM AFTER sync_tracker fixture has applied the patch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
llm = LLM(
model=model,
max_model_len=256,
speculative_config={
"method": method,
"num_speculative_tokens": num_spec_tokens,
"model": spec_model,
},
enforce_eager=True,
async_scheduling=True,
)
outputs = llm.generate(
["Hello, my name is"],
SamplingParams(temperature=0, max_tokens=10),
)
assert len(outputs) == 1
assert len(outputs[0].outputs[0].text) > 0
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
sync_tracker.assert_no_sync()

View File

@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 82.5% acceptance rate at the end.
assert last_accept_rate > 0.825
# Heuristic: expect at least 80.0% acceptance rate at the end.
assert last_accept_rate > 0.80
del spec_llm
torch.cuda.empty_cache()

View File

@ -88,8 +88,8 @@ def forward_attention(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
num_computed_tokens_cpu=context_lens.cpu(),
_seq_lens_cpu=seq_lens.cpu(),
_num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,

View File

@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
request.use_structured_output = True
request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
request.num_computed_tokens = 5
return request
def test_should_fill_bitmask_with_enable_in_reasoning(

View File

@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm.envs as envs
from vllm.profiler.gpu_profiler import WorkerProfiler
from vllm.config import ProfilerConfig
from vllm.profiler.wrapper import WorkerProfiler
class ConcreteWorkerProfiler(WorkerProfiler):
@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
A basic implementation of a worker profiler for testing purposes.
"""
def __init__(self):
def __init__(self, profiler_config: ProfilerConfig):
self.start_call_count = 0
self.stop_call_count = 0
self.should_fail_start = False
super().__init__()
super().__init__(profiler_config)
def _start(self) -> None:
if self.should_fail_start:
@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
self.stop_call_count += 1
@pytest.fixture(autouse=True)
def reset_mocks():
"""Fixture to reset mocks and env variables before each test."""
envs.VLLM_PROFILER_DELAY_ITERS = 0
envs.VLLM_PROFILER_MAX_ITERS = 0
@pytest.fixture
def default_profiler_config():
return ProfilerConfig(
profiler="torch",
torch_profiler_dir="/tmp/mock",
delay_iterations=0,
max_iterations=0,
)
def test_immediate_start_stop():
def test_immediate_start_stop(default_profiler_config):
"""Test standard start without delay."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
assert profiler._running is True
assert profiler._active is True
@ -48,10 +50,10 @@ def test_immediate_start_stop():
assert profiler.stop_call_count == 1
def test_delayed_start():
def test_delayed_start(default_profiler_config):
"""Test that profiler waits for N steps before actually starting."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
# User requests start
profiler.start()
@ -71,10 +73,10 @@ def test_delayed_start():
assert profiler.start_call_count == 1
def test_max_iterations():
def test_max_iterations(default_profiler_config):
"""Test that profiler stops automatically after max iterations."""
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
assert profiler._running is True
@ -95,12 +97,11 @@ def test_max_iterations():
assert profiler.stop_call_count == 1
def test_delayed_start_and_max_iters():
def test_delayed_start_and_max_iters(default_profiler_config):
"""Test combined delayed start and max iterations."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
# Step 1
@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
assert profiler.stop_call_count == 1
def test_idempotency():
def test_idempotency(default_profiler_config):
"""Test that calling start/stop multiple times doesn't break logic."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Double Start
profiler.start()
@ -142,10 +143,10 @@ def test_idempotency():
assert profiler.stop_call_count == 1 # Should only stop once
def test_step_inactive():
def test_step_inactive(default_profiler_config):
"""Test that stepping while inactive does nothing."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Not started yet
profiler.step()
@ -155,9 +156,9 @@ def test_step_inactive():
assert profiler.start_call_count == 0
def test_start_failure():
def test_start_failure(default_profiler_config):
"""Test behavior when the underlying _start method raises exception."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.should_fail_start = True
profiler.start()
@ -168,9 +169,9 @@ def test_start_failure():
assert profiler.start_call_count == 0 # Logic failed inside start
def test_shutdown():
def test_shutdown(default_profiler_config):
"""Test that shutdown calls stop only if running."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Case 1: Not running
profiler.shutdown()
@ -182,10 +183,10 @@ def test_shutdown():
assert profiler.stop_call_count == 1
def test_mixed_delay_and_stop():
def test_mixed_delay_and_stop(default_profiler_config):
"""Test manual stop during the delay period."""
envs.VLLM_PROFILER_DELAY_ITERS = 5
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 5
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
profiler.step()

View File

@ -9,6 +9,8 @@ import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
_FP8_DTYPE = current_platform.fp8_dtype()
def is_aiter_found() -> bool:
from importlib.util import find_spec
@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import dtypes
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@ -467,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return torch.empty_like(x), torch.empty_like(residual)
def _rocm_aiter_per_tensor_quant_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import per_tensor_quant_hip
return per_tensor_quant_hip(x, scale, quant_dtype)
def _rocm_aiter_per_tensor_quant_fake(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x, dtype=quant_dtype), torch.empty(
1, dtype=torch.float32, device=x.device
)
def _rocm_aiter_per_token_quant_impl(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import dynamic_per_token_scaled_quant
assert quant_dtype in [torch.int8, _FP8_DTYPE]
out_shape = x.shape
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
if scale is None:
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
dynamic_per_token_scaled_quant(
out,
x,
scale,
scale_ub=None,
shuffle_scale=False,
num_rows=None,
num_rows_factor=1,
)
return out, scale
def _rocm_aiter_per_token_quant_fake(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
out_shape = x.shape
return (
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=residual,
)
return (x_quant, x_quant_scales, res)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
torch.empty_like(residual, device=residual.device),
)
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=None,
)
return (x_quant, x_quant_scales)
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
)
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N_half + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@ -502,7 +672,7 @@ class rocm_aiter_ops:
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
return cls.is_linear_enabled()
@classmethod
@if_aiter_supported
@ -577,14 +747,6 @@ class rocm_aiter_ops:
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_group_fp8_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl,
@ -644,27 +806,62 @@ class rocm_aiter_ops:
direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8_blockscale",
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
fake_impl=_rocm_aiter_group_fp8_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_per_tensor_quant",
op_func=_rocm_aiter_per_tensor_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_per_tensor_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_per_token_quant",
op_func=_rocm_aiter_per_token_quant_impl,
mutates_args=["scale"],
fake_impl=_rocm_aiter_per_token_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True
@staticmethod
@ -859,6 +1056,22 @@ class rocm_aiter_ops:
kv_scale=kv_scale,
)
@staticmethod
def per_tensor_quant(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)
@staticmethod
def per_token_quant(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor,

View File

@ -1726,7 +1726,7 @@ def scaled_fp8_quant(
output, input, scale, scale_ub
)
else:
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
scale = torch.empty(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"

View File

@ -89,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
try:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
else:
flash_attn_varlen_func = None

View File

@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata.seq_lens_cpu = torch.from_numpy(
new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)

View File

@ -12,7 +12,6 @@ from typing import Any
import numpy as np
from tqdm import tqdm
import vllm.envs as envs
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace):
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
raise OSError(
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
"Please set it to a valid path to use torch profiler."
)
engine_args = EngineArgs.from_cli_args(args)
if args.profile and not engine_args.profiler_config.profiler == "torch":
raise ValueError(
"The torch profiler is not enabled. Please provide profiler_config."
)
# Lazy import to avoid importing LLM when the bench command is not selected.
from vllm import LLM, SamplingParams
@ -144,7 +142,7 @@ def main(args: argparse.Namespace):
run_to_completion(profile_dir=None)
if args.profile:
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
profile_dir = engine_args.profiler_config.torch_profiler_dir
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return

View File

@ -1097,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
parser.add_argument(
"--save-result",

View File

@ -655,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--profile",
action="store_true",
default=False,
help="Use Torch Profiler. The env variable "
"VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
# prefix repetition dataset

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import functools
import hashlib
import inspect
@ -8,15 +10,17 @@ import json
import types
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any
from typing import TYPE_CHECKING, Any
import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.config.utils import Range
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING:
from vllm.config.utils import Range
if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass
else:

View File

@ -5,6 +5,7 @@ import functools
from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]

View File

@ -53,8 +53,27 @@ class PiecewiseBackend:
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1
# TODO: we need to generalize encoder compilation to other models
self.is_encoder_compilation = vllm_backend.prefix in [
"Qwen2_5_VisionPatchEmbed",
"Qwen2_5_VisionPatchMerger",
"Qwen2_5_VisionBlock",
]
self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
# For encoder compilation we use the max int32 value
# to set the upper bound of the compile ranges
max_int32 = 2**31 - 1
last_compile_range = self.compile_ranges[-1]
assert (
last_compile_range.end
== vllm_config.scheduler_config.max_num_batched_tokens
)
self.compile_ranges[-1] = Range(
start=last_compile_range.start, end=max_int32
)
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string)

View File

@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fusion import empty_bf16
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
)
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
class AiterRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm & group fp8 quant custom
ops into an aiter rms_norm_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_GROUP_QUANT_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
return at[0], at[1]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class AiterFusedAddRMSFp8GroupQuantPattern:
"""
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
into a aiter rms_norm_with_add_group_fp8_quant op.
"""
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
self.epsilon = epsilon
self.quant_dtype = quant_dtype
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at1 = AITER_RMS_ADD_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
)
at2 = self.quant_op(at1[0], 128)
# result, scale, residual
return at2[0], at2[1], at1[1]
def replacement(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
):
at = AITER_RMS_ADD_GROUP_QUANT_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon,
group_size=128,
)
# result, scale, residual
return at[0], at[1], at[2]
inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + dynamic group fp8 quant
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
self.patterns
)
AiterFusedAddRMSFp8GroupQuantPattern(
epsilon, FP8_DTYPE, quant_op
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
fusion_patterns = [
AiterRMSFp8GroupQuantPattern,
AiterFusedAddRMSFp8GroupQuantPattern,
]
return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
"""
def __init__(self, quant_op: OpOverload):
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor,
):
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
):
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
inputs = [
self.silu_and_mul_matcher.inputs()[0],
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)

View File

@ -24,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
@ -89,6 +90,8 @@ __all__ = [
"SpeechToTextConfig",
# From vllm.config.structured_outputs
"StructuredOutputsConfig",
# From vllm.config.profiler
"ProfilerConfig",
# From vllm.config.utils
"ConfigType",
"SupportsMetricsInfo",

199
vllm/config/profiler.py Normal file
View File

@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Literal
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
@config
@dataclass
class ProfilerConfig:
"""Dataclass which contains profiler config for the engine."""
profiler: ProfilerKind | None = None
"""Which profiler to use. Defaults to None. Options are:
- 'torch': Use PyTorch profiler.\n
- 'cuda': Use CUDA profiler."""
torch_profiler_dir: str = ""
"""Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
worker's traces (CPU & GPU) will be saved under this directory. Note that
it must be an absolute path."""
torch_profiler_with_stack: bool = True
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
torch_profiler_with_flops: bool = False
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
torch_profiler_use_gzip: bool = True
"""If `True`, saves torch profiler traces in gzip format. Enabled by default"""
torch_profiler_dump_cuda_time_total: bool = True
"""If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""
torch_profiler_record_shapes: bool = False
"""If `True`, records tensor shapes in the torch profiler. Disabled by default."""
torch_profiler_with_memory: bool = False
"""If `True`, enables memory profiling in the torch profiler.
Disabled by default."""
ignore_frontend: bool = False
"""If `True`, disables the front-end profiling of AsyncLLM when using the
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
since the front-end profiling does not track iterations and will capture the
entire range.
"""
delay_iterations: int = Field(default=0, ge=0)
"""Number of engine iterations to skip before starting profiling.
Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
"""
max_iterations: int = Field(default=0, ge=0)
"""Maximum number of engine iterations to profile after starting profiling.
Defaults to 0, meaning no limit.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
"""Get field from env var if set, with deprecation warning."""
if envs.is_set(env_var_name):
value = getattr(envs, env_var_name)
logger.warning_once(
"Using %s environment variable is deprecated and will be removed in "
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
"--profiler-config.%s command line argument or "
"ProfilerConfig(%s=...) config field instead.",
env_var_name,
field_name,
field_name,
)
return value
return None
def _set_from_env_if_set(
self,
field_name: str,
env_var_name: str,
to_bool: bool = True,
to_int: bool = False,
) -> None:
"""Set field from env var if set, with deprecation warning."""
value = self._get_from_env_if_set(field_name, env_var_name)
if value is not None:
if to_bool:
value = value == "1"
if to_int:
value = int(value)
setattr(self, field_name, value)
@model_validator(mode="after")
def _validate_profiler_config(self) -> Self:
maybe_use_cuda_profiler = self._get_from_env_if_set(
"profiler", "VLLM_TORCH_CUDA_PROFILE"
)
if maybe_use_cuda_profiler is not None:
self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
else:
self._set_from_env_if_set(
"torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
)
if self.torch_profiler_dir:
self.profiler = "torch"
self._set_from_env_if_set(
"torch_profiler_record_shapes",
"VLLM_TORCH_PROFILER_RECORD_SHAPES",
)
self._set_from_env_if_set(
"torch_profiler_with_memory",
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
)
self._set_from_env_if_set(
"torch_profiler_with_stack",
"VLLM_TORCH_PROFILER_WITH_STACK",
)
self._set_from_env_if_set(
"torch_profiler_with_flops",
"VLLM_TORCH_PROFILER_WITH_FLOPS",
)
self._set_from_env_if_set(
"ignore_frontend",
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
)
self._set_from_env_if_set(
"torch_profiler_use_gzip",
"VLLM_TORCH_PROFILER_USE_GZIP",
)
self._set_from_env_if_set(
"torch_profiler_dump_cuda_time_total",
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
)
self._set_from_env_if_set(
"delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
)
self._set_from_env_if_set(
"max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
)
has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
logger.warning_once(
"Using 'torch' profiler with delay_iterations or max_iterations "
"while ignore_frontend is False may result in high overhead."
)
profiler_dir = self.torch_profiler_dir
if profiler_dir and self.profiler != "torch":
raise ValueError(
"torch_profiler_dir is only applicable when profiler is set to 'torch'"
)
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
if profiler_dir:
is_gs_path = (
profiler_dir.startswith("gs://")
and profiler_dir[5:]
and profiler_dir[5] != "/"
)
if not is_gs_path:
self.torch_profiler_dir = os.path.abspath(
os.path.expanduser(profiler_dir)
)
return self

View File

@ -39,6 +39,7 @@ from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
@ -218,6 +219,8 @@ class VllmConfig:
You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
"""Profiling configuration."""
kv_transfer_config: KVTransferConfig | None = None
"""The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None
@ -296,6 +299,8 @@ class VllmConfig:
vllm_factors.append("None")
if self.structured_outputs_config:
vllm_factors.append(self.structured_outputs_config.compute_hash())
if self.profiler_config:
vllm_factors.append(self.profiler_config.compute_hash())
else:
vllm_factors.append("None")
vllm_factors.append(self.observability_config.compute_hash())
@ -1042,8 +1047,14 @@ class VllmConfig:
self.compilation_config.max_cudagraph_capture_size
)
if max_cudagraph_capture_size is None:
decode_query_len = 1
if (
self.speculative_config
and self.speculative_config.num_speculative_tokens
):
decode_query_len += self.speculative_config.num_speculative_tokens
max_cudagraph_capture_size = min(
self.scheduler_config.max_num_seqs * 2, 512
self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

View File

@ -50,6 +50,7 @@ from vllm.config import (
ObservabilityConfig,
ParallelConfig,
PoolerConfig,
ProfilerConfig,
SchedulerConfig,
SpeculativeConfig,
StructuredOutputsConfig,
@ -532,6 +533,8 @@ class EngineArgs:
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls
profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")
kv_transfer_config: KVTransferConfig | None = None
kv_events_config: KVEventsConfig | None = None
@ -1164,7 +1167,7 @@ class EngineArgs:
vllm_group.add_argument(
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
)
vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
vllm_group.add_argument(
"--optimization-level", **vllm_kwargs["optimization_level"]
)
@ -1782,6 +1785,7 @@ class EngineArgs:
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
profiler_config=self.profiler_config,
additional_config=self.additional_config,
optimization_level=self.optimization_level,
)

View File

@ -8,3 +8,5 @@ Shared constants for vLLM entrypoints.
# These constants help mitigate header abuse attacks
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
H11_MAX_HEADER_COUNT_DEFAULT = 256
MCP_PREFIX = "mcp_"

View File

@ -19,6 +19,7 @@ from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.parser.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
@ -303,7 +304,7 @@ class ParsableContext(ConversationContext):
result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
id=f"mcpo_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
@ -385,6 +386,9 @@ class ParsableContext(ConversationContext):
if not self.parser.response_messages:
return []
last_msg = self.parser.response_messages[-1]
# change this to a mcp_ function call
last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
self.parser.response_messages[-1] = last_msg
if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
elif last_msg.name == "web_search_preview":

View File

@ -20,6 +20,7 @@ from vllm.beam_search import (
from vllm.config import (
CompilationConfig,
PoolerConfig,
ProfilerConfig,
StructuredOutputsConfig,
is_init_field,
)
@ -211,6 +212,7 @@ class LLM:
structured_outputs_config: dict[str, Any]
| StructuredOutputsConfig
| None = None,
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None,
@ -282,6 +284,20 @@ class LLM:
else:
structured_outputs_instance = StructuredOutputsConfig()
if profiler_config is not None:
if isinstance(profiler_config, dict):
profiler_config_instance = ProfilerConfig(
**{
k: v
for k, v in profiler_config.items()
if is_init_field(ProfilerConfig, k)
}
)
else:
profiler_config_instance = profiler_config
else:
profiler_config_instance = ProfilerConfig()
# warn about single-process data parallel usage.
_dp_size = int(kwargs.get("data_parallel_size", 1))
_distributed_executor_backend = kwargs.get("distributed_executor_backend")
@ -324,6 +340,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config,
structured_outputs_config=structured_outputs_instance,
profiler_config=profiler_config_instance,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
**kwargs,

View File

@ -1339,6 +1339,7 @@ class OpenAIServing:
)
engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0]
prompt_text, _, _ = self._get_prompt_components(request_prompt)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(

View File

@ -99,12 +99,7 @@ class MistralToolParser(ToolParser):
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if not _is_pre_v11_tokeniser(self.model_tokenizer):
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
)
else:
self.fn_name_regex = None
self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
if self.bot_token_id is None:
raise RuntimeError(
@ -148,23 +143,24 @@ class MistralToolParser(ToolParser):
tool_content = model_output.replace(self.bot_token, "").strip()
try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
if self.fn_name_regex:
if not self._is_pre_v11:
function_call_arr = []
for single_tool_content in model_output.split(self.bot_token):
matches = self.fn_name_regex.findall(single_tool_content)
if "{" not in single_tool_content:
continue
for match in matches:
fn_name = match[0]
args = match[1]
end_name = single_tool_content.find("{")
fn_name, args = (
single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:

View File

@ -22,6 +22,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
from vllm import envs
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam,
ResponseInputOutputItem,
@ -44,13 +45,13 @@ def make_response_output_items_from_parsable_context(
)
if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall(
id=f"mcp_{random_uuid()}",
id=f"{MCP_PREFIX}{random_uuid()}",
arguments=output_messages[-1].arguments,
name=output_messages[-1].name,
server_label=output_messages[
-1
].name, # TODO: store the server label
type="mcp_call",
type=f"{MCP_PREFIX}call",
status="completed",
output=message.output,
# TODO: support error output
@ -98,12 +99,63 @@ def construct_input_messages(
if isinstance(request_input, str):
messages.append({"role": "user", "content": request_input})
else:
for item in request_input:
messages.append(construct_chat_message_with_tool_call(item))
input_messages = construct_chat_messages_with_tool_call(request_input)
messages.extend(input_messages)
return messages
def construct_chat_message_with_tool_call(
def _maybe_combine_reasoning_and_tool_call(
item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
"""Many models treat MCP calls and reasoning as a single message.
This function checks if the last message is a reasoning message and
the current message is a tool call"""
if not (
isinstance(item, ResponseFunctionToolCall) and item.id.startswith(MCP_PREFIX)
):
return None
if len(messages) == 0:
return None
last_message = messages[-1]
if not (
last_message.get("role") == "assistant"
and last_message.get("reasoning") is not None
):
return None
last_message["tool_calls"] = [
ChatCompletionMessageToolCallParam(
id=item.call_id,
function=FunctionCallTool(
name=item.name,
arguments=item.arguments,
),
type="function",
)
]
return last_message
def construct_chat_messages_with_tool_call(
input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
"""This function wraps _construct_single_message_from_response_item
Because some chatMessages come from multiple response items
for example a reasoning item and a MCP tool call are two response items
but are one chat message
"""
messages: list[ChatCompletionMessageParam] = []
for item in input_messages:
maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
if maybe_combined_message is not None:
messages[-1] = maybe_combined_message
else:
messages.append(_construct_single_message_from_response_item(item))
return messages
def _construct_single_message_from_response_item(
item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
if isinstance(item, ResponseFunctionToolCall):

View File

@ -5,7 +5,7 @@
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response
import vllm.envs as envs
from vllm.config import ProfilerConfig
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
@ -35,15 +35,12 @@ async def stop_profile(raw_request: Request):
def attach_router(app: FastAPI):
if envs.VLLM_TORCH_PROFILER_DIR:
profiler_config = getattr(app.state.args, "profiler_config", None)
assert profiler_config is None or isinstance(profiler_config, ProfilerConfig)
if profiler_config is not None and profiler_config.profiler is not None:
logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
"Profiler with mode '%s' is enabled in the "
"API server. This should ONLY be used for local development!",
profiler_config.profiler,
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
logger.warning_once(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
app.include_router(router)

View File

@ -75,6 +75,7 @@ if TYPE_CHECKING:
VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False
@ -88,20 +89,23 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_TORCH_CUDA_PROFILE: bool = False
# Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None
VLLM_TORCH_PROFILER_DIR: str | None = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None
VLLM_TORCH_PROFILER_WITH_STACK: str | None = None
VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None
VLLM_TORCH_PROFILER_USE_GZIP: str | None = None
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None
VLLM_PROFILER_DELAY_ITERS: str | None = None
VLLM_PROFILER_MAX_ITERS: str | None = None
# End of deprecated env variables for profiling
VLLM_USE_AOT_COMPILE: bool = False
VLLM_USE_BYTECODE_HOOK: bool = False
VLLM_FORCE_AOT_LOAD: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_PROFILER_DELAY_ITERS: int = 0
VLLM_PROFILER_MAX_ITERS: int = 0
VLLM_TORCH_PROFILER_USE_GZIP: bool = True
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@ -452,6 +456,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Main CUDA version of vLLM. This follows PyTorch but can be overridden.
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
or "12.9",
# Controls PyTorch float32 matmul precision mode within vLLM workers.
# Valid options mirror torch.set_float32_matmul_precision
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
"VLLM_FLOAT32_MATMUL_PRECISION",
"highest",
["highest", "high", "medium"],
case_sensitive=False,
),
# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
@ -841,71 +853,52 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None
),
# Enables torch CUDA profiling if set.
# On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered.
"VLLM_TORCH_CUDA_PROFILE": lambda: bool(
os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
),
# Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
# Enables torch profiler if set.
# Both AsyncLLM's CPU traces as well as workers'
# traces (CPU & GPU) will be saved under this directory.
# Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR": lambda: (
None
if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None
else (
val
if val.startswith("gs://") and val[5:] and val[5] != "/"
else os.path.abspath(os.path.expanduser(val))
)
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"),
# Enable torch profiler to record shapes if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: (
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES")
),
# Enable torch profiler to record shapes if set
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
# not record shapes.
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"
# Enable torch profiler to profile memory if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: (
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY")
),
# Enable torch profiler to profile memory if set
# VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler
# will not profile memory.
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"
# Enable torch profiler to profile stack if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: (
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK")
),
# Enable torch profiler to profile stack if set
# VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL
# profile stack by default.
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"
# Enable torch profiler to profile flops if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: (
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS")
),
# Enable torch profiler to profile flops if set
# VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will
# not profile flops.
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
),
# Disable torch profiling of the AsyncLLMEngine process.
# If set to 1, will not profile the engine process.
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
# Disable torch profiling of the AsyncLLMEngine process if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: (
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM")
),
# Delay number of iterations before starting profiling when using
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
"VLLM_PROFILER_DELAY_ITERS": lambda: int(
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
),
# Deprecated, see profiler_config.
"VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")),
# Maximum number of iterations to profile when using the torch/torch CUDA profiler.
# If set to 0, will not limit the number of iterations.
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
"VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"),
# Control whether torch profiler gzip-compresses profiling files.
# Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default).
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0"
),
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"),
# Control whether torch profiler dumps the self_cuda_time_total table.
# Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping
# (enabled by default).
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0"
# Set to 0 to disable dumping the table.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: (
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL")
),
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),

View File

@ -292,7 +292,7 @@ def set_forward_context(
if num_tokens_across_dp is None:
assert ubatch_slices is None
assert num_tokens is not None
_, num_tokens_across_dp = coordinate_batch_across_dp(
_, num_tokens_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config,
allow_microbatching=False,

View File

@ -935,7 +935,11 @@ def enable_batch_invariant_mode():
# Batch invariant matmuls are no longer needed after cublas overrides
if not is_torch_equal_or_newer("2.10.0.dev"):
if current_platform.is_device_capability(100):
if (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(80)
or current_platform.is_device_capability(89)
):
# For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")

View File

@ -4,7 +4,10 @@
from contextlib import contextmanager
from typing import Any
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
@ -49,6 +52,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"RoutingMethodType",
"SharedFusedMoE",
"activation_without_mul",
"override_config",

View File

@ -895,6 +895,48 @@ def get_moe_configs(
return None
def _ensure_block_size_k_divisible(
size_k: int, block_size_k: int, group_size: int
) -> int:
"""Ensure block_size_k is a divisor of size_k and divisible by group_size.
This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which
requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0.
Args:
size_k: The size_k dimension that must be divisible by result.
block_size_k: Preferred block size (will be adjusted if needed).
group_size: The result must be divisible by this.
Returns:
A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size.
"""
# Fast path: already valid
if size_k % block_size_k == 0 and block_size_k % group_size == 0:
return block_size_k
# Find the largest value that:
# 1. Divides size_k (size_k % candidate == 0)
# 2. Is divisible by group_size (candidate % group_size == 0)
# 3. Is <= block_size_k (prefer smaller values close to block_size_k)
#
# Strategy: Search from min(block_size_k, size_k) down to group_size,
# stepping by group_size to ensure divisibility by group_size
max_search = min(block_size_k, size_k)
start = (max_search // group_size) * group_size
for candidate in range(start, group_size - 1, -group_size):
if size_k % candidate == 0:
return candidate
# Fallback: if group_size divides size_k, use it
# This should always be true with correct group_size configuration
if size_k % group_size == 0:
return group_size
# This should not happen with correct group_size, but ensure divisibility
return size_k
def get_moe_wna16_block_config(
config: dict[str, int],
use_moe_wna16_cuda: bool,
@ -960,6 +1002,9 @@ def get_moe_wna16_block_config(
# at the same time.
block_size_n = 1024
# Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Callable
import torch
@ -100,22 +99,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
@ -97,23 +96,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
@ -127,10 +109,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else expert_map,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:

View File

@ -33,10 +33,6 @@ from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
@ -57,11 +53,8 @@ from vllm.utils.torch_utils import (
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
from .fused_moe import eplb_map_to_physical_and_record, fused_experts
from .fused_moe import eplb_map_to_physical_and_record
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = object # type: ignore
FusedMoEPrepareAndFinalize = object # type: ignore
def _eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
@ -483,7 +476,7 @@ class FusedMoE(CustomOp):
enable_eplb=self.enable_eplb,
)
self.expert_map: torch.Tensor | None
self._expert_map: torch.Tensor | None
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
@ -493,7 +486,7 @@ class FusedMoE(CustomOp):
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("_expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables()
logger.info_once(
@ -506,10 +499,10 @@ class FusedMoE(CustomOp):
self.expert_placement_strategy,
self.local_num_experts,
self.global_num_experts,
get_compressed_expert_map(self.expert_map),
get_compressed_expert_map(self._expert_map),
)
else:
self.local_num_experts, self.expert_map, self.expert_mask = (
self.local_num_experts, self._expert_map, self.expert_mask = (
self.global_num_experts,
None,
None,
@ -781,7 +774,7 @@ class FusedMoE(CustomOp):
),
)
if self.expert_map is None:
if self._expert_map is None:
return None
routing_tables = self.ensure_round_robin_expert_routing_tables(
@ -789,7 +782,7 @@ class FusedMoE(CustomOp):
ep_size=self.ep_size,
ep_rank=self.ep_rank,
local_num_experts=self.local_num_experts,
device=self.expert_map.device,
device=self._expert_map.device,
)
global_to_physical, physical_to_global, local_global = routing_tables
@ -840,8 +833,8 @@ class FusedMoE(CustomOp):
def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
with self.expert_map.device:
assert self._expert_map is not None
with self._expert_map.device:
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
@ -851,7 +844,7 @@ class FusedMoE(CustomOp):
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("_expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables()
if self.aiter_fmoe_shared_expert_enabled:
@ -888,7 +881,7 @@ class FusedMoE(CustomOp):
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream())
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
@ -1068,9 +1061,9 @@ class FusedMoE(CustomOp):
expert_data.copy_(loaded_weight)
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None:
if self._expert_map is None:
return expert_id
return self.expert_map[expert_id].item()
return self._expert_map[expert_id].item()
def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int
@ -1744,6 +1737,12 @@ class FusedMoE(CustomOp):
reduce_output(fused_output)[..., :og_hidden_states],
)
@property
def expert_map(self) -> torch.Tensor | None:
return (
self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
@ -1805,24 +1804,6 @@ class FusedMoE(CustomOp):
layer=self,
x=staged_hidden_states,
router_logits=staged_router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
if has_separate_shared_experts:
@ -1968,25 +1949,6 @@ class FusedMoE(CustomOp):
if do_naive_dispatch_combine
else hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
if has_separate_shared_experts:

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import torch.nn.functional as F
@ -269,53 +268,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: torch.nn.Module,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
return self.forward(
x=x,
layer=layer,
x=x,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
enable_eplb=enable_eplb,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def get_fused_moe_quant_config(
@ -333,24 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
@ -364,9 +307,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=layer.expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe(
@ -375,8 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
result = fused_experts(
@ -386,11 +329,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
activation=layer.activation,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
@ -405,148 +348,101 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
enable_eplb is not False
or expert_load_view is not None
or logical_to_physical_map is not None
or logical_replica_count is not None
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for CPU.")
return layer.cpu_fused_moe(
layer,
x,
use_grouped_topk,
top_k,
layer.use_grouped_topk,
layer.top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
routed_scaling_factor,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
layer.global_num_experts,
layer.expert_map,
layer.custom_routing_function,
layer.scoring_func,
layer.routed_scaling_factor,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.activation,
)
def forward_xpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
enable_eplb is not False
or expert_load_view is not None
or logical_to_physical_map is not None
or logical_replica_count is not None
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
layer.use_grouped_topk,
layer.top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function=custom_routing_function,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
custom_routing_function=layer.custom_routing_function,
)
def forward_tpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
assert not layer.use_grouped_topk
assert layer.num_expert_group is None
assert layer.topk_group is None
assert layer.custom_routing_function is None
assert layer.apply_router_weight_on_input is False
if layer.scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU."
)
if e_score_correction_bias is not None:
if layer.e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU."
)
assert activation == "silu", f"{activation} is not supported for TPU."
assert routed_scaling_factor == 1.0, (
f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU."
assert layer.activation == "silu", (
f"{layer.activation} is not supported for TPU."
)
assert layer.routed_scaling_factor == 1.0, (
f"routed_scaling_factor {layer.routed_scaling_factor} is "
"not supported for TPU."
)
if (
enable_eplb is not False
or expert_load_view is not None
or logical_to_physical_map is not None
or logical_replica_count is not None
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for TPU.")
return fused_moe_pallas(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
topk=layer.top_k,
gating_output=router_logits,
global_num_experts=global_num_experts,
expert_map=expert_map,
renormalize=renormalize,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
renormalize=layer.renormalize,
)
if current_platform.is_tpu():

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
import torch
@ -669,25 +668,8 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported."
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
@ -708,9 +690,9 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Union
import torch
@ -498,23 +497,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -534,10 +516,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from collections.abc import Callable
from enum import Enum
import torch
@ -558,31 +557,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported."
assert layer.activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if enable_eplb:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
)
@ -591,12 +573,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer=layer,
x=x,
router_logits=router_logits,
top_k=top_k,
global_num_experts=global_num_experts,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
e_score_correction_bias=e_score_correction_bias,
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
)
topk_weights, topk_ids, _ = layer.select_experts(
@ -619,9 +601,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
@ -646,15 +628,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert expert_map is None, (
assert layer.expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoEMethod."
@ -670,7 +652,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
@ -1188,23 +1170,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
@ -1215,7 +1180,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe(
x,
layer.w13_weight,
@ -1228,9 +1195,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
@ -1248,9 +1215,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
@ -1270,10 +1237,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=None
if self.disable_expert_map
else layer.expert_map, # ???
quant_config=self.moe_quant_config,
)
else:
@ -1290,9 +1259,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights,
topk_ids,
quant_config=self.moe_quant_config,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else layer.expert_map,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
@ -1314,10 +1283,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
@ -1437,23 +1406,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -1469,10 +1421,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
@ -1814,25 +1766,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
@ -1853,9 +1790,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
@ -2057,23 +1994,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -2089,10 +2009,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
@ -2372,32 +2292,15 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert activation in ("silu", "swigluoai", "swiglu"), (
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert layer.activation in ("silu", "swigluoai", "swiglu"), (
"Only SiLU/SwiGLUGU/SwiGLUUG are supported."
)
assert expert_map is None, """expert_map/EP not implemented
assert layer.expert_map is None, """expert_map/EP not implemented
for CPU dyn-4bit MoE."""
def _act_kind(s: str) -> int:
@ -2414,15 +2317,9 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
top_k=layer.top_k,
use_grouped_topk=layer.use_grouped_topk,
renormalize=layer.renormalize,
)
return torch.ops._C.dynamic_4bit_int_moe(
@ -2435,8 +2332,8 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_in_features,
layer.w13_out_features,
layer.group_size,
apply_router_weight_on_input,
int(_act_kind(activation)),
layer.apply_router_weight_on_input,
int(_act_kind(layer.activation)),
)
@ -2707,28 +2604,11 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
):
if enable_eplb:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
)
@ -2749,9 +2629,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights,
topk_ids,
quant_config=self.moe_quant_config,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=None if self.disable_expert_map else layer.expert_map,
a_strides1=self.a_strides1_c_strides2,
a_strides2=self.a_strides2,
b_strides1=self.b_strides1,

View File

@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
@classmethod
def get_min_capability(cls) -> int:
# dont restrict as emulations
# don't restrict as emulations
return 80
def create_weights(

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional
import torch
@ -140,23 +139,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -172,10 +154,10 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Optional
@ -99,7 +98,7 @@ from vllm.model_executor.parameter import (
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
@ -559,46 +558,50 @@ class Fp8LinearMethod(LinearMethodBase):
assert not self.act_q_static
size_k_first = False
weight, weight_scale = process_fp8_weight_block_strategy(
weight, weight_scale_inv = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale_inv
)
# Delete the weight_scale_inv parameter to avoid confusion
# with the weight_scale parameter
del layer.weight_scale_inv
# Update layer with new values
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
# If checkpoint not serialized fp8, quantize the weights.
elif not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else:
weight = layer.weight
weight_scale = layer.weight_scale
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
weight = qweight.t()
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if not self.use_marlin:
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
weight,
weight_scale,
layer.logical_widths,
getattr(layer, "input_scale", None),
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else:
weight = layer.weight
weight_scale = layer.weight_scale
# Update layer with new values.
layer.weight = Parameter(weight.data, requires_grad=False)
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
layer.input_scale = (
Parameter(input_scale, requires_grad=False)
if input_scale is not None
else None
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if not self.use_marlin:
weight, weight_scale, input_scale = (
process_fp8_weight_tensor_strategy(
weight,
weight_scale,
layer.logical_widths,
getattr(layer, "input_scale", None),
)
)
if self.act_q_static:
assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
# Update layer with new values.
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
if input_scale is not None:
replace_parameter(layer, "input_scale", input_scale)
else:
layer.input_scale = None
if self.use_marlin:
prepare_fp8_layer_for_marlin(
@ -625,7 +628,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
@ -654,10 +657,15 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
if self.block_quant:
weight_scale = layer.weight_scale_inv
else:
weight_scale = layer.weight_scale
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale=weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
@ -671,7 +679,7 @@ class Fp8LinearMethod(LinearMethodBase):
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
@ -941,22 +949,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight_scale_inv = layer.w2_weight_scale_inv
# torch.compile() cannot use Parameter subclasses.
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(
w13_weight_scale_inv, requires_grad=False
)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(
w2_weight_scale_inv, requires_grad=False
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
@ -994,13 +998,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(
replace_parameter(
layer,
"w13_weight_scale",
torch.ones(
layer.local_num_experts,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
@ -1009,16 +1014,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
@ -1039,12 +1045,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
@ -1058,22 +1060,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
replace_parameter(layer, "w2_input_scale", w2_input_scale)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
@ -1097,12 +1091,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w2_weight
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
replace_parameter(layer, "w13_weight_scale", max_w13_scales)
if self.flashinfer_moe_backend is not None:
# NOTE: weights have to be swapped since the activation is
@ -1245,41 +1237,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
@ -1293,29 +1264,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=routed_scaling_factor,
routed_scaling=layer.routed_scaling_factor,
)
else:
assert not renormalize and custom_routing_function is not None
assert (
not layer.renormalize and layer.custom_routing_function is not None
)
result = apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
select_result = layer.select_experts(
@ -1336,13 +1309,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
result = fused_marlin_moe(
x,
layer.w13_weight,
@ -1355,20 +1330,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
if not self.block_quant:
assert not renormalize and custom_routing_function is not None
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
assert (
not layer.renormalize and layer.custom_routing_function is not None
)
assert layer.scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
)
# Delegate to CUTLASS FlashInfer path; function already bound with
# use_deepseek_fp8_block_scale for block-quant when applicable
@ -1378,10 +1355,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -1393,10 +1370,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Mapping
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional
@ -625,26 +625,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused GGUF MoE method."
@ -662,7 +645,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type,
activation,
layer.activation,
)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from copy import deepcopy
from typing import Any, Optional
@ -790,25 +789,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert activation == "silu", "Only SiLU activation is supported."
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
@ -829,9 +811,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,

View File

@ -5,6 +5,7 @@ import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
@ -45,10 +46,13 @@ class QuantFP8(CustomOp):
super().__init__()
self.static = static
self.group_shape = group_shape
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales
self.use_ue8m0 = use_ue8m0
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled()
self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant:
assert not static, "Group quantization only supports dynamic mode"
@ -92,6 +96,33 @@ class QuantFP8(CustomOp):
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
def forward_hip(
self,
x: torch.Tensor,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
use_aiter_quant = (
not self.is_group_quant
and self.use_aiter
and scale_ub is None
and x.is_contiguous()
)
use_aiter_per_tensor_quant = (
use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
)
use_aiter_per_token_quant = (
use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
)
if use_aiter_per_tensor_quant:
return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
if use_aiter_per_token_quant:
return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
# Fallback to CUDA implementation
return self.forward_cuda(x, scale, scale_ub)
def forward_native(
self,
x: torch.Tensor,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional
import torch
@ -440,31 +439,14 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
layer.use_grouped_topk,
layer.top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function=custom_routing_function,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
custom_routing_function=layer.custom_routing_function,
)

View File

@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# skip if there are no weights to process (for example, weight reloading)
if not hasattr(layer, "q_scale"):
assert not hasattr(layer, "k_scale")
assert not hasattr(layer, "v_scale")
assert not hasattr(layer, "prob_scale")
return
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional
@ -706,43 +705,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
)
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
assert not renormalize
assert not layer.renormalize
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
# Expert selection
@ -752,9 +735,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert activation in ("silu", "relu2_no_mul"), (
assert layer.activation in ("silu", "relu2_no_mul"), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {activation}"
f"but got {layer.activation}"
)
return flashinfer_cutlass_moe_fp8(
x,
@ -762,10 +745,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
@ -779,11 +762,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
activation=layer.activation,
quant_config=self.moe_quant_config,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
@ -1503,23 +1486,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.moe.is_act_and_mul:
assert (
@ -1534,7 +1500,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if enable_eplb:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
@ -1542,12 +1508,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer=layer,
x=x,
router_logits=router_logits,
top_k=top_k,
global_num_experts=global_num_experts,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
e_score_correction_bias=e_score_correction_bias,
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
)
topk_weights, topk_ids, _ = layer.select_experts(
@ -1570,9 +1536,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
)
@ -1603,10 +1569,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
@ -1621,8 +1587,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO: derive from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional
import torch
@ -60,7 +59,7 @@ class MoeWNA16Config(QuantizationConfig):
if self.linear_quant_method == "gptq":
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
elif self.linear_quant_method == "awq":
elif self.linear_quant_method in ("awq", "awq_marlin"):
capability_tuple = current_platform.get_device_capability()
device_capability = (
-1 if capability_tuple is None else capability_tuple.to_int()
@ -107,7 +106,7 @@ class MoeWNA16Config(QuantizationConfig):
if linear_quant_method == "gptq":
has_zp = not cls.get_from_keys(config, ["sym"])
modules_to_not_convert = []
elif linear_quant_method == "awq":
elif linear_quant_method in ("awq", "awq_marlin"):
has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
@ -184,7 +183,7 @@ class MoeWNA16Config(QuantizationConfig):
return GPTQConfig.from_config(self.full_config).get_quant_method(
layer, prefix
)
elif self.linear_quant_method == "awq":
elif self.linear_quant_method in ("awq", "awq_marlin"):
if self.use_marlin and check_marlin_supports_layer(
layer, self.group_size
):
@ -362,27 +361,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -395,9 +377,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
@ -468,7 +450,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq":
# awq_marlin uses the same weight format as awq
if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
assert layer.quant_config.weight_bits == 4
if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qweight")

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum
from typing import Optional
@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_scale1=None,
global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
)
assert _can_support_mxfp4(
use_grouped_topk,
topk_group,
num_expert_group,
expert_map,
custom_routing_function,
e_score_correction_bias,
apply_router_weight_on_input,
scoring_func,
activation,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
layer.use_grouped_topk,
layer.topk_group,
layer.num_expert_group,
layer.expert_map,
layer.custom_routing_function,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.expert_load_view,
layer.logical_to_physical_map,
layer.logical_replica_count,
), "MXFP4 are not supported with this configuration."
if (
@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output1_scale_scalar
None, # output1_scale_gate_scalar
None, # output2_scale_scalar
global_num_experts,
top_k,
layer.global_num_experts,
layer.top_k,
None, # n_group
None, # topk_group
self.intermediate_size, # padded to multiple of 256
@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.num_experts, # local num experts
None,
None,
1 if renormalize else 0, # routing_method_type, renormalize
1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts,
expert_map=expert_map,
topk=layer.top_k,
renormalize=layer.renormalize,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
assert activation == "swigluoai", (
assert layer.activation == "swigluoai", (
"Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
)
hidden_size_pad = round_up(self.original_hidden_size, 128)
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
hidden_states = layer.ipex_fusion(
x_pad,
use_grouped_topk,
top_k,
layer.use_grouped_topk,
layer.top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
activation="swiglu_oai",
)
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
@ -337,23 +336,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
@ -371,13 +353,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
quant_config=self.moe_quant_config,
expert_map=expert_map,
expert_map=layer.expert_map,
)
elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
return fused_marlin_moe(
x,
layer.w13_weight,
@ -390,9 +374,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -404,10 +388,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
@ -597,23 +581,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
@ -631,9 +598,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=expert_map,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@ -645,10 +612,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
return out

View File

@ -3,7 +3,6 @@
# Copyright © 2025, Oracle and/or its affiliates.
import os
from collections.abc import Callable
from typing import Any, Optional
import numpy as np
@ -359,23 +358,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
@ -394,9 +376,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=self.quant_config.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
workspace=workspace,
)

View File

@ -27,10 +27,10 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
@ -195,6 +195,39 @@ direct_register_custom_op(
)
def _triton_per_token_group_quant_fp8_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
return per_token_group_quant_fp8(
x, group_size, column_major_scales=False, use_ue8m0=False
)
def _triton_per_token_group_quant_fp8_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=current_platform.fp8_dtype(), device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
direct_register_custom_op(
"triton_per_token_group_quant_fp8",
_triton_per_token_group_quant_fp8_impl,
fake_impl=_triton_per_token_group_quant_fp8_fake,
)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp:
@ -214,6 +247,7 @@ class W8A8BlockFp8LinearOp:
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
self.is_blackwell = current_platform.is_device_capability(100)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
# Get the correct blockscale mul and input quant operations.
@ -269,7 +303,7 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
if self.use_deep_gemm_e8m0 and self.is_blackwell:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
@ -340,17 +374,15 @@ class W8A8BlockFp8LinearOp:
if input_scale is not None:
q_input = input_2d
# MI350 case uses triton kernel
elif use_triton:
q_input, input_scale = per_token_group_quant_fp8(
q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
)
# MI300 uses tuned AITER ASM/C++ kernel
else:
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
input_2d, self.act_quant_group_shape.col
)
return gemm_a8w8_blockscale_op(
q_input,
@ -1404,12 +1436,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
if should_use_deepgemm:
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data,
ws=layer.weight_scale.data,
ws=layer.weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)
replace_parameter(layer, "weight", dg_weight)
replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:

View File

@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
size_n=part_size_n,
num_bits=8,
)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
replace_parameter(layer, "weight", marlin_qweight)
# WEIGHT SCALES
# Permute scales
@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
scales = layer.weight_scale.to(layer.orig_dtype)
elif "weight_scale_inv" in dir(layer):
scales = layer.weight_scale_inv.to(layer.orig_dtype)
del layer.weight_scale_inv
group_size = -1 if weight_block_size is None else weight_block_size[1]
@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
)
if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
if hasattr(layer, "weight_scale"):
replace_parameter(layer, "weight_scale", marlin_scales)
elif hasattr(layer, "weight_scale_inv"):
replace_parameter(layer, "weight_scale_inv", marlin_scales)
if hasattr(layer, "bias") and layer.bias is not None:
assert layer.bias.shape == (part_size_n,)
bias = marlin_permute_bias(layer.bias)
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
replace_parameter(layer, "bias", bias)
def prepare_moe_fp8_layer_for_marlin(

View File

@ -95,8 +95,11 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case (since)
# we already are quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
#
# Extra note: upon weight reloading weight_scale.ndim == 0
unfused_module_in_checkpoint = (
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
weight_scale.ndim != 0
and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
)
# If unfused checkpoint, need requanize with the single scale.

View File

@ -367,6 +367,8 @@ class Qwen2MoeModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
@ -512,6 +514,12 @@ class Qwen2MoeModel(nn.Module):
continue
else:
name = remapped_kv_scale_name
# GGUF: make sure that shared_expert_gate is a 2D tensor.
if (
"mlp.shared_expert_gate" in name
and len(loaded_weight.shape) == 1
):
loaded_weight = loaded_weight[None, :]
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader

View File

@ -403,6 +403,7 @@ class Qwen3MoeModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@ -505,6 +506,19 @@ class Qwen3MoeModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert loaded_weight.numel() == 1, (
f"KV scale numel {loaded_weight.numel()} != 1"
)
loaded_weight = loaded_weight.squeeze()
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:

View File

@ -50,6 +50,31 @@ def set_weight_attrs(
setattr(weight, key, value)
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
"""
Replace a parameter of a layer while maintaining the ability to reload the weight.
Called within implementations of the `process_weights_after_loading` method.
This function should not be called on weights which are tied/shared
Args:
layer: Layer containing parameter to replace
param_name: Name of parameter to replace
new_data: New data of the new parameter
"""
# should not be used on a tied/shared param
if isinstance(new_data, torch.nn.Parameter):
new_data = new_data.data
new_param = torch.nn.Parameter(new_data, requires_grad=False)
old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
if old_param is not None and hasattr(old_param, "weight_loader"):
weight_loader = old_param.weight_loader
set_weight_attrs(new_param, {"weight_loader": weight_loader})
setattr(layer, param_name, new_param)
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
parent_map = getattr(model, "packed_modules_mapping", None)
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}

View File

@ -381,6 +381,8 @@ class RocmPlatform(Platform):
compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
@ -400,8 +402,6 @@ class RocmPlatform(Platform):
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
@ -415,6 +415,9 @@ class RocmPlatform(Platform):
):
compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS:

Some files were not shown because too many files have changed in this diff Show More