mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 09:37:03 +08:00
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
commit
52e2a31a95
@ -398,7 +398,8 @@ steps:
|
||||
timeout_in_minutes: 25
|
||||
gpu: h100
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- vllm/v1/attention
|
||||
- vllm/model_executor/layers
|
||||
- tests/v1/determinism/
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
@ -440,23 +441,29 @@ steps:
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
source_file_dependencies:
|
||||
- vllm/entrypoints
|
||||
- vllm/multimodal
|
||||
- examples/
|
||||
commands:
|
||||
- pip install tensorizer # for tensorizer test
|
||||
# for basic
|
||||
- python3 offline_inference/basic/chat.py
|
||||
- python3 offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
- python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10
|
||||
- python3 offline_inference/basic/chat.py
|
||||
- python3 offline_inference/prefix_caching.py
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 offline_inference/basic/classify.py
|
||||
- python3 offline_inference/basic/embed.py
|
||||
- python3 offline_inference/basic/score.py
|
||||
# for multi-modal models
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
- python3 offline_inference/basic/classify.py
|
||||
- python3 offline_inference/basic/embed.py
|
||||
- python3 offline_inference/basic/score.py
|
||||
# for pooling models
|
||||
- python3 pooling/pooling/vision_language_pooling.py --seed 0
|
||||
# for features demo
|
||||
- python3 offline_inference/prefix_caching.py
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
|
||||
# https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU
|
||||
- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536
|
||||
@ -718,6 +725,18 @@ steps:
|
||||
- uv pip install --system conch-triton-kernels
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
|
||||
|
||||
- label: LM Eval Small Models # 53min
|
||||
timeout_in_minutes: 75
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
autorun_on_main: true
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
|
||||
- label: OpenAI API correctness # 10min
|
||||
timeout_in_minutes: 15
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
@ -727,7 +746,7 @@ steps:
|
||||
- csrc/
|
||||
- vllm/entrypoints/openai/
|
||||
- vllm/model_executor/models/whisper.py
|
||||
commands: # LMEval
|
||||
commands: # LMEval+Transcription WER check
|
||||
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
@ -963,6 +982,19 @@ steps:
|
||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
||||
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
|
||||
- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min
|
||||
timeout_in_minutes: 180
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- vllm/multimodal/
|
||||
- vllm/inputs/
|
||||
- vllm/v1/core/
|
||||
commands:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1 # 60min
|
||||
timeout_in_minutes: 120
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -1098,7 +1130,6 @@ steps:
|
||||
- vllm/model_executor/layers/layernorm.py
|
||||
- vllm/model_executor/layers/activation.py
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
- vllm/model_executor/layers/fused_moe/layer.py
|
||||
- tests/compile/test_fusion_attn.py
|
||||
- tests/compile/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/distributed/test_fusion_all_reduce.py
|
||||
@ -1132,12 +1163,25 @@ steps:
|
||||
- vllm/model_executor/layers/activation.py
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
- tests/compile/distributed/test_fusions_e2e.py
|
||||
- tests/compile/fullgraph/test_full_graph.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
# Run all e2e fusion tests
|
||||
- pytest -v -s tests/compile/distributed/test_fusions_e2e.py
|
||||
|
||||
- label: Blackwell GPT-OSS Eval
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
optional: true # run on nightlies
|
||||
source_file_dependencies:
|
||||
- tests/evals/gpt_oss
|
||||
- vllm/model_executor/models/gpt_oss.py
|
||||
- vllm/model_executor/layers/quantization/mxfp4.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
commands:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
|
||||
|
||||
- label: Blackwell Quantized MoE Test
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
@ -1155,6 +1199,16 @@ steps:
|
||||
commands:
|
||||
- pytest -s -v tests/quantization/test_blackwell_moe.py
|
||||
|
||||
- label: Blackwell LM Eval Small Models
|
||||
timeout_in_minutes: 120
|
||||
gpu: b200
|
||||
optional: true # run on nightlies
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -1397,6 +1451,39 @@ steps:
|
||||
- TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest -v -s -x lora/test_mixtral.py
|
||||
|
||||
|
||||
- label: LM Eval Large Models # optional
|
||||
gpu: a100
|
||||
optional: true
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||
|
||||
##### H100 test #####
|
||||
- label: LM Eval Large Models (H100) # optional
|
||||
gpu: h100
|
||||
optional: true
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
|
||||
|
||||
|
||||
##### H200 test #####
|
||||
- label: Distributed Tests (H200) # optional
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -1440,29 +1527,6 @@ steps:
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
|
||||
- label: Blackwell LM Eval Small Models
|
||||
timeout_in_minutes: 120
|
||||
gpu: b200
|
||||
optional: true # run on nightlies
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||
|
||||
- label: Multi-Modal Accuracy Eval (Small Models) # 10min
|
||||
timeout_in_minutes: 70
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- vllm/multimodal/
|
||||
- vllm/inputs/
|
||||
- vllm/v1/core/
|
||||
commands:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
|
||||
|
||||
- label: LM Eval Large Models (4 Card)
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_4
|
||||
@ -1478,21 +1542,6 @@ steps:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||
|
||||
- label: LM Eval Large Models (H100) # optional
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
|
||||
|
||||
- label: ROCm LM Eval Large Models (8 Card)
|
||||
mirror_hardwares: [amdproduction]
|
||||
agent_pool: mi325_8
|
||||
@ -1517,6 +1566,20 @@ steps:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
|
||||
|
||||
##### RL Integration Tests #####
|
||||
- label: Prime-RL Integration Test # 15min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_2
|
||||
# grade: Blocking
|
||||
timeout_in_minutes: 30
|
||||
optional: true
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- .buildkite/scripts/run-prime-rl-test.sh
|
||||
commands:
|
||||
- bash .buildkite/scripts/run-prime-rl-test.sh
|
||||
- label: DeepSeek V2-Lite Accuracy
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_4
|
||||
@ -1550,17 +1613,26 @@ steps:
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
|
||||
##### RL Integration Tests #####
|
||||
- label: Prime-RL Integration Test # 15min
|
||||
- label: DeepSeek V2-Lite Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_2
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
timeout_in_minutes: 30
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 2
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- .buildkite/scripts/run-prime-rl-test.sh
|
||||
commands:
|
||||
- bash .buildkite/scripts/run-prime-rl-test.sh
|
||||
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
|
||||
|
||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
||||
timeout_in_minutes: 60
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
@ -468,7 +468,9 @@ steps:
|
||||
# tests covered elsewhere.
|
||||
# Use `find` to launch multiple instances of pytest so that
|
||||
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
|
||||
# However, find does not normally propagate error codes, so we combine it with xargs
|
||||
# (using -0 for proper path handling)
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
@ -482,7 +484,9 @@ steps:
|
||||
# as it is a heavy test that is covered in other steps.
|
||||
# Use `find` to launch multiple instances of pytest so that
|
||||
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
|
||||
- "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\\\;"
|
||||
# However, find does not normally propagate error codes, so we combine it with xargs
|
||||
# (using -0 for proper path handling)
|
||||
- "find compile/fullgraph -maxdepth 1 -name 'test_*.py' -not -name 'test_full_graph.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
|
||||
|
||||
- label: PyTorch Fullgraph Test # 27min
|
||||
timeout_in_minutes: 40
|
||||
|
||||
2
.github/workflows/cleanup_pr_body.yml
vendored
2
.github/workflows/cleanup_pr_body.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
|
||||
|
||||
2
.github/workflows/macos-smoke-test.yml
vendored
2
.github/workflows/macos-smoke-test.yml
vendored
@ -12,7 +12,7 @@ jobs:
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v6.0.1
|
||||
|
||||
- uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
|
||||
2
.github/workflows/pre-commit.yml
vendored
2
.github/workflows/pre-commit.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
|
||||
- uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
actions: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
||||
with:
|
||||
# Increasing this value ensures that changes to this workflow
|
||||
# propagate to all issues and PRs in days rather than months
|
||||
|
||||
@ -96,8 +96,9 @@ start_server() {
|
||||
# This correctly passes each element as a separate argument.
|
||||
if [[ -n "$profile_dir" ]]; then
|
||||
# Start server with profiling enabled
|
||||
VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
|
||||
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
|
||||
local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}"
|
||||
VLLM_SERVER_DEV_MODE=1 \
|
||||
vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 &
|
||||
else
|
||||
# Start server without profiling
|
||||
VLLM_SERVER_DEV_MODE=1 \
|
||||
|
||||
@ -963,8 +963,7 @@ def create_argument_parser():
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Use Torch Profiler. The endpoint must be launched with "
|
||||
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
|
||||
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result-dir",
|
||||
|
||||
@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
||||
endif()
|
||||
|
||||
# Build ACL with CMake
|
||||
set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
set(ARM_COMPUTE_ARCH "armv8.2-a")
|
||||
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
|
||||
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
|
||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
|
||||
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
|
||||
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
|
||||
set(ARM_COMPUTE_BUILD_TESTING "OFF")
|
||||
|
||||
set(_cmake_config_cmd
|
||||
${CMAKE_COMMAND} -G Ninja -B build
|
||||
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
|
||||
|
||||
@ -186,7 +186,7 @@ struct AttentionMetadata {
|
||||
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
|
||||
// * q_tile_size * 4, partial output, max + sum (float)
|
||||
// Reduction scratchpad contains:
|
||||
// - flags: bool array to indicate wether the split is finished
|
||||
// - flags: bool array to indicate whether the split is finished
|
||||
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
|
||||
// - max, sum: 2 * split_num * q_tile_size * 4
|
||||
class AttentionScratchPad {
|
||||
|
||||
@ -617,7 +617,7 @@ struct MacheteCollectiveMma {
|
||||
|
||||
// Same as upstream, should be kept the same when possible, not formatted for
|
||||
// easier comparison
|
||||
// with `SwapAB ? N : M -> M` since we dont support SwapAB
|
||||
// with `SwapAB ? N : M -> M` since we don't support SwapAB
|
||||
// clang-format off
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
|
||||
@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
|
||||
}
|
||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
|
||||
// Find the min val of div2 that doesn't increase N/(div1*div2)
|
||||
int mindiv(int N, int div1, int div2) {
|
||||
int nPrRnd = div1 * div2;
|
||||
int rnds0 = N / nPrRnd;
|
||||
nPrRnd -= div1 * 3;
|
||||
int rnds3 = N / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
int rnds4 = N / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
int rnds5 = N / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
int rnds6 = N / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
int rnds7 = N / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
int rnds8 = N / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
int rnds9 = N / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
int rtn = div2;
|
||||
if (rnds0 == rnds3) rtn = div2 - 3;
|
||||
if (rnds0 == rnds4) rtn = div2 - 4;
|
||||
if (rnds0 == rnds5) rtn = div2 - 5;
|
||||
if (rnds0 == rnds6) rtn = div2 - 6;
|
||||
if (rnds0 == rnds7) rtn = div2 - 7;
|
||||
if (rnds0 == rnds8) rtn = div2 - 8;
|
||||
if (rnds0 == rnds9) rtn = div2 - 9;
|
||||
return rtn;
|
||||
int rnds[13];
|
||||
for (int i = 0; i < 13; i++) {
|
||||
rnds[i] = (N + nPrRnd - 1) / nPrRnd;
|
||||
nPrRnd -= div1;
|
||||
}
|
||||
for (int i = 12; i >= 0; i--)
|
||||
if (rnds[0] == rnds[i]) return (div2 - i);
|
||||
}
|
||||
|
||||
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const int max_lds_len = get_lds_size() / 2;
|
||||
|
||||
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
|
||||
_N) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} else if (K_in * N_in <= max_lds_len * 1.2) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
|
||||
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} \
|
||||
#define WVSPLITK(_YTILE, _UNRL, _N) \
|
||||
{ \
|
||||
dim3 block(64, 16); \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
|
||||
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
else if (K_in * N_in <= max_lds_len * 1.2) \
|
||||
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
else \
|
||||
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
}
|
||||
|
||||
#define WVSPLIT_TILE(_sYT, __N) \
|
||||
{ \
|
||||
bool fit_lds = (K_in * N_in <= max_lds_len); \
|
||||
if (_sYT <= 1) \
|
||||
WVSPLITK(1, 4, __N) \
|
||||
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
|
||||
WVSPLITK(2, 2, __N) \
|
||||
else if (_sYT <= 4 * 3) \
|
||||
WVSPLITK(3, 2, __N) \
|
||||
else if (__N == 4) \
|
||||
WVSPLITK(4, 1, __N) \
|
||||
else \
|
||||
WVSPLITK(4, 2, __N) \
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
|
||||
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
||||
: nullptr;
|
||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||
|
||||
// first shoot for biggest tile-size that keeps all simd busy,
|
||||
// then cut the active waves to balance their distribution...
|
||||
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
|
||||
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
|
||||
WVSPLIT_TILE(sYT, 1)
|
||||
break;
|
||||
case 2:
|
||||
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
|
||||
WVSPLIT_TILE(sYT, 2)
|
||||
break;
|
||||
case 3:
|
||||
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
|
||||
WVSPLIT_TILE(sYT, 3)
|
||||
break;
|
||||
case 4:
|
||||
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
|
||||
WVSPLIT_TILE(sYT, 4)
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes.
|
||||
- [vllm.config.MultiModalConfig][]
|
||||
- [vllm.config.PoolerConfig][]
|
||||
- [vllm.config.StructuredOutputsConfig][]
|
||||
- [vllm.config.ProfilerConfig][]
|
||||
- [vllm.config.ObservabilityConfig][]
|
||||
- [vllm.config.KVTransferConfig][]
|
||||
- [vllm.config.CompilationConfig][]
|
||||
|
||||
@ -5,16 +5,15 @@
|
||||
|
||||
## Profile with PyTorch Profiler
|
||||
|
||||
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
|
||||
We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by setting `--profiler-config`
|
||||
when launching the server, and setting the entries `profiler` to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. Additionally, you can control the profiling content by specifying the following additional arguments in the config:
|
||||
|
||||
- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default
|
||||
- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
|
||||
- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
|
||||
- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
|
||||
- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
|
||||
- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
|
||||
|
||||
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
|
||||
- `torch_profiler_record_shapes` to enable recording Tensor Shapes, off by default
|
||||
- `torch_profiler_with_memory` to record memory, off by default
|
||||
- `torch_profiler_with_stack` to enable recording stack information, on by default
|
||||
- `torch_profiler_with_flops` to enable recording FLOPs, off by default
|
||||
- `torch_profiler_use_gzip` to control gzip-compressing profiling files, on by default
|
||||
- `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default
|
||||
|
||||
When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.
|
||||
|
||||
@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
|
||||
#### OpenAI Server
|
||||
|
||||
```bash
|
||||
VLLM_TORCH_PROFILER_DIR=./vllm_profile \
|
||||
vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||
vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
|
||||
```
|
||||
|
||||
vllm bench command:
|
||||
@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with `
|
||||
|
||||
```bash
|
||||
# server
|
||||
VLLM_TORCH_CUDA_PROFILE=1 \
|
||||
nsys profile \
|
||||
--trace-fork-before-exec=true \
|
||||
--cuda-graph-trace=node \
|
||||
--capture-range=cudaProfilerApi \
|
||||
--capture-range-end repeat \
|
||||
vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||
vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda
|
||||
|
||||
# client
|
||||
vllm bench serve \
|
||||
|
||||
@ -22,7 +22,7 @@ python tools/install_nixl_from_source_ubuntu.py
|
||||
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
|
||||
|
||||
```bash
|
||||
# Example UCX configuration, adjust according to your enviroment
|
||||
# Example UCX configuration, adjust according to your environment
|
||||
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
|
||||
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
|
||||
```
|
||||
|
||||
@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.end_token_id in input_ids
|
||||
|
||||
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
|
||||
return self.end_token_id in delta_token_ids
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@ -1,14 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# enable torch profiler, can also be set on cmd line
|
||||
os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@ -22,7 +18,14 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
tensor_parallel_size=1,
|
||||
profiler_config={
|
||||
"profiler": "torch",
|
||||
"torch_profiler_dir": "./vllm_profile",
|
||||
},
|
||||
)
|
||||
|
||||
llm.start_profile()
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ torchgeo==0.7.0
|
||||
mteb==2.1.2
|
||||
|
||||
# Data processing
|
||||
xgrammar==0.1.27
|
||||
xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84
|
||||
# Test async scheduling
|
||||
|
||||
# Utilities
|
||||
|
||||
@ -17,7 +17,6 @@ def test_compile():
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
@pytest.mark.xfail
|
||||
def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
|
||||
"""Test that Qwen2.5-VL vision submodules are compiled.
|
||||
|
||||
|
||||
@ -80,6 +80,8 @@ def test_compile_ranges(use_fresh_inductor_cache):
|
||||
vllm_config = VllmConfig(
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_num_batched_tokens=8192,
|
||||
max_model_len=8192,
|
||||
is_encoder_decoder=False,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
@ -112,6 +114,8 @@ def test_compile_config_get_compile_ranges():
|
||||
VllmConfig(
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_num_batched_tokens=8192,
|
||||
max_model_len=8192,
|
||||
is_encoder_decoder=False,
|
||||
),
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
@ -134,6 +138,8 @@ def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
|
||||
)
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_batched_tokens=8192,
|
||||
max_model_len=8192,
|
||||
is_encoder_decoder=False,
|
||||
)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.config
|
||||
import vllm.plugins
|
||||
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
@ -237,13 +241,85 @@ def _generate_kernel_groupshape_combinations():
|
||||
KERNEL_GROUPSHAPE_COMBINATIONS = _generate_kernel_groupshape_combinations()
|
||||
|
||||
|
||||
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
scale_hidden_size = (hidden_size + 128 - 1) // 128
|
||||
self.wscale = [
|
||||
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
|
||||
self.eps = eps
|
||||
|
||||
self.w8a8_block_fp8_linear = [
|
||||
TestBlockFP8Layer(
|
||||
GroupShape(128, 128),
|
||||
self.w[i],
|
||||
self.wscale[i],
|
||||
cutlass_block_fp8_supported=False,
|
||||
use_aiter_and_is_supported=True,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
x = resid = torch.relu(x)
|
||||
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
|
||||
|
||||
x2 = self.w8a8_block_fp8_linear[0](y)
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||
x2, resid, self.norm_weight[1], self.eps
|
||||
)
|
||||
|
||||
x3 = self.w8a8_block_fp8_linear[1](y2)
|
||||
|
||||
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||
x3, resid, self.norm_weight[2], self.eps
|
||||
)
|
||||
|
||||
x4 = self.w8a8_block_fp8_linear[2](y3)
|
||||
|
||||
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||
x4, resid, self.norm_weight[3], self.eps
|
||||
)
|
||||
return y4
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.rocm_aiter_rms_norm,
|
||||
torch.ops.vllm.rocm_aiter_group_fp8_quant,
|
||||
]
|
||||
|
||||
def ops_in_model_before_partial(self):
|
||||
return []
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
|
||||
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
|
||||
list(itertools.product([TestModel], [True, False], [True, False]))
|
||||
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
@ -253,9 +329,13 @@ def test_fusion_rmsnorm_quant(
|
||||
num_tokens,
|
||||
eps,
|
||||
kernel_groupshape,
|
||||
model_class,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
):
|
||||
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||
pytest.skip("AITER is not supported on this GPU.")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
@ -290,7 +370,14 @@ def test_fusion_rmsnorm_quant(
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
if model_class is TestRmsnormGroupFp8QuantModel:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
|
||||
else:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
@ -325,7 +412,10 @@ def test_fusion_rmsnorm_quant(
|
||||
# there's a risk that the fused add doesn't get included in the
|
||||
# replacement and only the rms part gets fused with quant.
|
||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||
if not enable_rms_norm_custom_op:
|
||||
if (
|
||||
not enable_rms_norm_custom_op
|
||||
and model_class is not TestRmsnormGroupFp8QuantModel
|
||||
):
|
||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||
|
||||
@ -5,9 +5,14 @@ import copy
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.compilation.inductor_pass import (
|
||||
CallableInductorPass,
|
||||
InductorPass,
|
||||
pass_context,
|
||||
)
|
||||
from vllm.compilation.pass_manager import PostGradPassManager
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
|
||||
|
||||
# dummy custom pass that doesn't inherit
|
||||
@ -42,35 +47,37 @@ class ProperPass(InductorPass):
|
||||
],
|
||||
)
|
||||
def test_pass_manager_uuid(callable):
|
||||
# Some passes need dtype to be set
|
||||
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
|
||||
# Set the pass context as PassManager uuid uses it
|
||||
with pass_context(Range(start=1, end=8)):
|
||||
# Some passes need dtype to be set
|
||||
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
|
||||
# Check that UUID is different if the same pass is added 2x
|
||||
pass_manager.add(callable)
|
||||
uuid1 = pass_manager.uuid()
|
||||
pass_manager.add(callable)
|
||||
uuid2 = pass_manager.uuid()
|
||||
assert uuid1 != uuid2
|
||||
# Check that UUID is different if the same pass is added 2x
|
||||
pass_manager.add(callable)
|
||||
uuid1 = pass_manager.uuid()
|
||||
pass_manager.add(callable)
|
||||
uuid2 = pass_manager.uuid()
|
||||
assert uuid1 != uuid2
|
||||
|
||||
# UUID should be the same as the original one,
|
||||
# as we constructed in the same way.
|
||||
pass_manager2 = PostGradPassManager()
|
||||
pass_manager2.configure(config)
|
||||
pass_manager2.add(callable)
|
||||
assert uuid1 == pass_manager2.uuid()
|
||||
# UUID should be the same as the original one,
|
||||
# as we constructed in the same way.
|
||||
pass_manager2 = PostGradPassManager()
|
||||
pass_manager2.configure(config)
|
||||
pass_manager2.add(callable)
|
||||
assert uuid1 == pass_manager2.uuid()
|
||||
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.compilation_config.pass_config.fuse_norm_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_norm_quant
|
||||
)
|
||||
config2.compilation_config.pass_config.fuse_act_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_act_quant
|
||||
)
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
pass_manager3.add(callable)
|
||||
assert uuid1 != pass_manager3.uuid()
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.compilation_config.pass_config.fuse_norm_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_norm_quant
|
||||
)
|
||||
config2.compilation_config.pass_config.fuse_act_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_act_quant
|
||||
)
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
pass_manager3.add(callable)
|
||||
assert uuid1 != pass_manager3.uuid()
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||
from vllm._aiter_ops import IS_AITER_FOUND
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
FUSED_OPS,
|
||||
@ -24,6 +25,7 @@ from vllm.config import (
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
return [FUSED_OPS[kNvfp4Quant]]
|
||||
|
||||
|
||||
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(128, 128),
|
||||
act_quant_group_shape=GroupShape(1, 128),
|
||||
cutlass_block_fp8_supported=False,
|
||||
use_aiter_and_is_supported=True,
|
||||
)
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
scale_hidden_size = (hidden_size + 128 - 1) // 128
|
||||
self.wscale = torch.rand(
|
||||
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
|
||||
]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [32, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [128, 256])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize(
|
||||
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
|
||||
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
|
||||
+ [(TestSiluMulNvfp4QuantModel, False, False)],
|
||||
+ [
|
||||
(TestSiluMulNvfp4QuantModel, False, False),
|
||||
(TestSiluMulGroupFp8QuantModel, False, False),
|
||||
],
|
||||
)
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
|
||||
model_class: type[
|
||||
TestSiluMulFp8QuantModel
|
||||
| TestSiluMulNvfp4QuantModel
|
||||
| TestSiluMulGroupFp8QuantModel
|
||||
],
|
||||
enable_silu_mul_custom_op: bool,
|
||||
enable_quant_fp8_custom_op: bool,
|
||||
cuda_force_torch: bool,
|
||||
):
|
||||
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||
pytest.skip("AITER is not supported on this GPU.")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
@ -172,9 +216,15 @@ def test_fusion_silu_and_mul_quant(
|
||||
)
|
||||
|
||||
with set_current_vllm_config(config):
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
fusion_passes = [ActivationQuantFusionPass(config)]
|
||||
if IS_AITER_FOUND:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
|
||||
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||
backend = TestBackend(*passes)
|
||||
model = model_class(
|
||||
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
|
||||
@ -193,12 +243,14 @@ def test_fusion_silu_and_mul_quant(
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
elif model_class == TestSiluMulGroupFp8QuantModel:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
assert sum([p.matched_count for p in fusion_passes]) == 1
|
||||
|
||||
# In pre-nodes, quant op should be present and fused kernels should not
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
ResponseFunctionToolCallOutputItem,
|
||||
)
|
||||
@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
|
||||
)
|
||||
|
||||
from vllm.entrypoints.responses_utils import (
|
||||
construct_chat_message_with_tool_call,
|
||||
_construct_single_message_from_response_item,
|
||||
construct_chat_messages_with_tool_call,
|
||||
convert_tool_responses_to_completions_format,
|
||||
)
|
||||
|
||||
@ -42,7 +44,43 @@ class TestResponsesUtils:
|
||||
|
||||
assert result == {"type": "function", "function": input_tool}
|
||||
|
||||
def test_construct_chat_message_with_tool_call(self):
|
||||
def test_construct_chat_messages_with_tool_call(self):
|
||||
"""Test construction of chat messages with tool calls."""
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id="lol",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Leroy Jenkins",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
mcp_tool_item = ResponseFunctionToolCall(
|
||||
id="mcp_123",
|
||||
call_id="call_123",
|
||||
type="function_call",
|
||||
status="completed",
|
||||
name="python",
|
||||
arguments='{"code": "123+456"}',
|
||||
)
|
||||
input_items = [reasoning_item, mcp_tool_item]
|
||||
messages = construct_chat_messages_with_tool_call(input_items)
|
||||
|
||||
assert len(messages) == 1
|
||||
message = messages[0]
|
||||
assert message["role"] == "assistant"
|
||||
assert message["reasoning"] == "Leroy Jenkins"
|
||||
assert message["tool_calls"][0]["id"] == "call_123"
|
||||
assert message["tool_calls"][0]["function"]["name"] == "python"
|
||||
assert (
|
||||
message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
|
||||
)
|
||||
|
||||
def test_construct_single_message_from_response_item(self):
|
||||
item = ResponseReasoningItem(
|
||||
id="lol",
|
||||
summary=[],
|
||||
@ -56,7 +94,7 @@ class TestResponsesUtils:
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted_item = construct_chat_message_with_tool_call(item)
|
||||
formatted_item = _construct_single_message_from_response_item(item)
|
||||
assert formatted_item["role"] == "assistant"
|
||||
assert formatted_item["reasoning"] == "Leroy Jenkins"
|
||||
|
||||
@ -74,7 +112,7 @@ class TestResponsesUtils:
|
||||
status=None,
|
||||
)
|
||||
|
||||
formatted_item = construct_chat_message_with_tool_call(item)
|
||||
formatted_item = _construct_single_message_from_response_item(item)
|
||||
assert formatted_item["role"] == "assistant"
|
||||
assert (
|
||||
formatted_item["reasoning"]
|
||||
@ -88,7 +126,7 @@ class TestResponsesUtils:
|
||||
output="1234",
|
||||
status="completed",
|
||||
)
|
||||
formatted_item = construct_chat_message_with_tool_call(tool_call_output)
|
||||
formatted_item = _construct_single_message_from_response_item(tool_call_output)
|
||||
assert formatted_item["role"] == "tool"
|
||||
assert formatted_item["content"] == "1234"
|
||||
assert formatted_item["tool_call_id"] == "temp"
|
||||
@ -102,7 +140,7 @@ class TestResponsesUtils:
|
||||
status=None,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
construct_chat_message_with_tool_call(item)
|
||||
_construct_single_message_from_response_item(item)
|
||||
|
||||
output_item = ResponseOutputMessage(
|
||||
id="msg_bf585bbbe3d500e0",
|
||||
@ -119,6 +157,6 @@ class TestResponsesUtils:
|
||||
type="message",
|
||||
)
|
||||
|
||||
formatted_item = construct_chat_message_with_tool_call(output_item)
|
||||
formatted_item = _construct_single_message_from_response_item(output_item)
|
||||
assert formatted_item["role"] == "assistant"
|
||||
assert formatted_item["content"] == "dongyi"
|
||||
|
||||
@ -7,7 +7,8 @@ import math
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||
@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len)
|
||||
]
|
||||
|
||||
|
||||
def get_attn_isa(
|
||||
block_size: int | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
if block_size and dtype:
|
||||
return _get_attn_isa(dtype, block_size)
|
||||
else:
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||
return "neon"
|
||||
elif torch._C._cpu._is_amx_tile_supported():
|
||||
return "amx"
|
||||
else:
|
||||
return "vec"
|
||||
|
||||
|
||||
# rand number generation takes too much time, cache rand tensors
|
||||
@functools.lru_cache(maxsize=128, typed=False)
|
||||
def tensor_cache(
|
||||
@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [96, 128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", QTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["neon"])
|
||||
@pytest.mark.skipif(
|
||||
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
|
||||
reason="Not an Arm CPU.",
|
||||
)
|
||||
def test_varlen_with_paged_kv_normal_neon(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize(
|
||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
)
|
||||
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||
def test_varlen_with_paged_kv_softcap(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [True])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize(
|
||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
)
|
||||
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||
def test_varlen_with_paged_kv_alibi(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [True])
|
||||
@pytest.mark.parametrize(
|
||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
)
|
||||
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||
def test_varlen_with_paged_kv_sink(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
|
||||
@ -26,7 +26,14 @@ def clear_cache():
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
|
||||
devices = ["cpu"]
|
||||
if current_platform.is_cuda():
|
||||
devices.append("cuda")
|
||||
if current_platform.is_rocm():
|
||||
devices.append("hip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
def test_mha_attn_platform(device: str):
|
||||
"""
|
||||
Test the attention selector between different platform and device.
|
||||
@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
|
||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
else:
|
||||
# Test CUDA with head_size=64 (divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
|
||||
@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
|
||||
.clamp(fp8_traits_min, fp8_traits_max)
|
||||
.to(FP8_DTYPE)
|
||||
)
|
||||
return ref_out, ref_scale.view((1, 1))
|
||||
return ref_out, ref_scale.view(1)
|
||||
|
||||
|
||||
def native_w8a8_block_matmul(
|
||||
|
||||
@ -54,6 +54,10 @@ def setup_cuda():
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_fp8_fnuz(),
|
||||
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,d,dtype,group_size,seed",
|
||||
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
|
||||
@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
|
||||
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_cutlass_matmul():
|
||||
# Test simple case where weight.shape % 128 != 0,
|
||||
@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul():
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_fp8_fnuz(),
|
||||
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||
|
||||
@ -15,6 +15,9 @@ from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
|
||||
@ -21,6 +21,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
@ -186,6 +187,7 @@ async def test_fetch_image_error_conversion():
|
||||
connector.fetch_image(broken_img)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3, reruns_delay=5)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
@ -198,8 +200,12 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
|
||||
}
|
||||
)
|
||||
|
||||
video_sync, metadata_sync = connector.fetch_video(video_url)
|
||||
video_async, metadata_async = await connector.fetch_video_async(video_url)
|
||||
try:
|
||||
video_sync, metadata_sync = connector.fetch_video(video_url)
|
||||
video_async, metadata_async = await connector.fetch_video_async(video_url)
|
||||
except (TimeoutError, asyncio.TimeoutError) as e:
|
||||
pytest.skip(f"Timeout fetching video (CI network flakiness): {e}")
|
||||
|
||||
assert np.array_equal(video_sync, video_async)
|
||||
assert metadata_sync == metadata_async
|
||||
|
||||
|
||||
@ -10,10 +10,14 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8Config,
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod,
|
||||
Fp8MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODELS = [
|
||||
@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
|
||||
# FP8 weight reloading does not support online quantization
|
||||
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
|
||||
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
|
||||
# any postprocessing that is applied to the weights such as padding and repacking
|
||||
# (excluding device sharding) must also be applied to the reloaded weights
|
||||
#
|
||||
# this is the case for marlin as well as per-tensor Fp8MoEMethod
|
||||
@pytest.mark.parametrize("use_marlin", [False]) # skip True
|
||||
def test_fp8_reloading(
|
||||
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
|
||||
):
|
||||
if is_checkpoint_fp8_serialized is False:
|
||||
pytest.skip("FP8 weight reloading does not support online quantization")
|
||||
|
||||
if method_cls is Fp8MoEMethod and weight_block_size is None:
|
||||
pytest.skip(
|
||||
"FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
|
||||
"If this is your use case, consider using a restore function like #26327"
|
||||
)
|
||||
|
||||
with torch.device("cuda:0"):
|
||||
config = Fp8Config(
|
||||
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
weight_block_size=weight_block_size,
|
||||
)
|
||||
|
||||
if method_cls is Fp8LinearMethod:
|
||||
layer = torch.nn.Linear(1, 1)
|
||||
method = method_cls(config)
|
||||
method.create_weights(
|
||||
layer=layer,
|
||||
input_size_per_partition=1,
|
||||
output_partition_sizes=[1],
|
||||
input_size=1,
|
||||
output_size=1,
|
||||
params_dtype=torch.bfloat16,
|
||||
weight_loader=default_weight_loader,
|
||||
)
|
||||
|
||||
else:
|
||||
layer = FusedMoE(
|
||||
num_experts=1,
|
||||
top_k=1,
|
||||
hidden_size=1,
|
||||
intermediate_size=1,
|
||||
)
|
||||
method = method_cls(config, layer)
|
||||
method.create_weights(
|
||||
layer=layer,
|
||||
num_experts=1,
|
||||
hidden_size=1,
|
||||
intermediate_size_per_partition=1,
|
||||
params_dtype=torch.bfloat16,
|
||||
weight_loader=default_weight_loader,
|
||||
)
|
||||
|
||||
method.use_marlin = use_marlin
|
||||
|
||||
# capture weights format during loading
|
||||
original_metadata = [
|
||||
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
|
||||
for name, param in layer.named_parameters()
|
||||
]
|
||||
|
||||
# test loading
|
||||
for name, shape, _ in original_metadata:
|
||||
param = getattr(layer, name)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, torch.zeros(shape)) # cannot use empty
|
||||
|
||||
method.process_weights_after_loading(layer)
|
||||
|
||||
# test reloading works after loading
|
||||
# assuming that no reshaping occurred
|
||||
for name, shape, original_weight_loader in original_metadata:
|
||||
param = getattr(layer, name)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
assert weight_loader is original_weight_loader
|
||||
weight_loader(param, torch.zeros(shape)) # cannot use empty
|
||||
|
||||
method.process_weights_after_loading(layer)
|
||||
|
||||
@ -132,6 +132,41 @@ class TestBaseThinkingReasoningParserMethods:
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_reasoning_end_streaming(self, test_tokenizer):
|
||||
"""Test the is_reasoning_end_streaming method."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
end_token_id = parser.end_token_id
|
||||
start_token_id = parser.start_token_id
|
||||
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
|
||||
is True
|
||||
)
|
||||
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
|
||||
assert parser.is_reasoning_end_streaming([], []) is False
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming(
|
||||
[1, start_token_id, 2, end_token_id], [end_token_id]
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
|
||||
)
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming(
|
||||
[1, start_token_id, 2, end_token_id, 2, start_token_id, 2],
|
||||
[2],
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming(
|
||||
[1, start_token_id, 2, end_token_id, 2, 2], [2]
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_extract_content_ids(self, test_tokenizer):
|
||||
"""Test the extract_content_ids method."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
|
||||
@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
|
||||
input_tokens = tokenizer.tokenize(input_text)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
||||
assert parser.is_reasoning_end(input_ids) is True
|
||||
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
|
||||
|
||||
# Test extract_content_ids returns all input_ids
|
||||
assert parser.extract_content_ids(input_ids) == input_ids
|
||||
|
||||
@ -615,6 +615,7 @@ def test_extract_tool_calls_streaming(
|
||||
"single_tool_weather",
|
||||
"multiple_tool_calls",
|
||||
"content_before_tool",
|
||||
"complex",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
@ -673,6 +674,21 @@ def test_extract_tool_calls_streaming(
|
||||
],
|
||||
"bla",
|
||||
),
|
||||
(
|
||||
# Complex
|
||||
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="bash",
|
||||
arguments=json.dumps(
|
||||
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming_one_chunk(
|
||||
|
||||
@ -53,9 +53,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
@ -1381,8 +1378,8 @@ class TestBlockFP8Layer:
|
||||
"""
|
||||
Test wrapper for W8A8BlockFp8LinearOp to match TestFP8Layer interface.
|
||||
|
||||
This is a workaround until W8A8BlockFp8LinearOp has a similar API to
|
||||
FP8ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization).
|
||||
This is a workaround until W8A8BlockFp8LinearOp implements
|
||||
ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -1390,20 +1387,21 @@ class TestBlockFP8Layer:
|
||||
group_shape: GroupShape,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
cutlass_block_fp8_supported: bool = False,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
):
|
||||
self.kernel = None # For compatibility with TestFP8Layer interface
|
||||
self.linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=False,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
||||
)
|
||||
self.weight = weight
|
||||
self.weight_scale = weight_scale
|
||||
self.input_scale = input_scale
|
||||
|
||||
def forward(
|
||||
def __call__(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.linear_op.apply(
|
||||
|
||||
@ -106,8 +106,8 @@ def create_common_attn_metadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=batch_spec.batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
|
||||
@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
assert key == BatchDescriptor(num_tokens=15)
|
||||
|
||||
# 4. Cascade attention should have a fall back mode
|
||||
# 4. disable_full should have a fall back mode (e.g., cascade attention)
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
|
||||
rt_mode, key = dispatcher.dispatch(
|
||||
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
|
||||
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
|
||||
)
|
||||
if "PIECEWISE" in cudagraph_mode_str: # string contains check
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
@ -10,6 +10,7 @@ from utils import (
|
||||
BACKENDS,
|
||||
_extract_step_logprobs,
|
||||
_random_prompt,
|
||||
is_device_capability_below_90,
|
||||
resolve_model_name,
|
||||
skip_unsupported,
|
||||
)
|
||||
@ -17,6 +18,8 @@ from utils import (
|
||||
import vllm.model_executor.layers.batch_invariant as batch_invariant
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@pytest.mark.timeout(1000)
|
||||
@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16", # not everything is supported
|
||||
gpu_memory_utilization=0.9,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
# Use more realistic prompts for better token generation
|
||||
@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=2048,
|
||||
dtype="bfloat16",
|
||||
enable_prefix_caching=False,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
prompt = "the capital of france is"
|
||||
@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
|
||||
@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
# Use a few test prompts
|
||||
@ -925,6 +933,8 @@ def LLM_with_max_seqs(
|
||||
max_model_len=max_model_len,
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
enable_prefix_caching=False,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
# Enable for MOE models
|
||||
# enable_expert_parallel=True,
|
||||
)
|
||||
|
||||
@ -11,8 +11,10 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
|
||||
reason="Requires CUDA and >= Hopper (SM90)",
|
||||
not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
|
||||
# Supports testing on Ampere and Ada Lovelace devices.
|
||||
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
|
||||
reason="Requires CUDA and >= Ampere (SM80)",
|
||||
)
|
||||
|
||||
BACKENDS: list[str] = [
|
||||
@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output):
|
||||
return t, inner.token_ids
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def is_device_capability_below_90() -> bool:
|
||||
return not current_platform.has_device_capability(90)
|
||||
|
||||
@ -8,6 +8,7 @@ import torch._dynamo.config as dynamo_config
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
|
||||
@ -70,6 +71,18 @@ def test_without_spec_decoding(
|
||||
(True, "uni", True, None, True),
|
||||
]
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# On ROCm, Only test with structured_outputs (deterministic)
|
||||
# and skip chunk_prefill (more variable).
|
||||
test_configs = [
|
||||
cfg
|
||||
for cfg in test_configs
|
||||
if not cfg[4] # skip chunk_prefill=True
|
||||
]
|
||||
test_sampling_params = [
|
||||
p for p in test_sampling_params if p.get("structured_outputs") is not None
|
||||
]
|
||||
|
||||
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
|
||||
|
||||
|
||||
@ -108,7 +121,14 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
||||
(True, "uni", True, spec_config_short, True),
|
||||
]
|
||||
|
||||
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
|
||||
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
|
||||
run_tests(
|
||||
monkeypatch,
|
||||
MTP_MODEL,
|
||||
test_configs,
|
||||
test_sampling_params,
|
||||
is_testing_with_spec_decoding=True,
|
||||
)
|
||||
|
||||
|
||||
@dynamo_config.patch(cache_size_limit=16)
|
||||
@ -117,13 +137,23 @@ def run_tests(
|
||||
model: str,
|
||||
test_configs: list[tuple],
|
||||
test_sampling_params: list[dict[str, Any]],
|
||||
is_testing_with_spec_decoding: bool = False,
|
||||
):
|
||||
"""Test consistency of combos of async scheduling, preemption,
|
||||
uni/multiproc executor with spec decoding."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
# avoid precision errors
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
if current_platform.is_rocm():
|
||||
if is_testing_with_spec_decoding:
|
||||
# Use TRITON_ATTN for spec decoding test for consistency
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
else:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
|
||||
else:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
# lock matmul precision to full FP32
|
||||
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
|
||||
# m.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
outputs: list[tuple[str, list, list]] = []
|
||||
for n, (
|
||||
@ -143,6 +173,7 @@ def run_tests(
|
||||
async_scheduling,
|
||||
spec_config,
|
||||
test_prefill_chunking=test_prefill_chunking,
|
||||
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
|
||||
)
|
||||
outputs.append(test_results)
|
||||
|
||||
@ -172,17 +203,34 @@ def run_tests(
|
||||
name_0=f"baseline=[{baseline_config}], params={params}",
|
||||
name_1=f"config=[{test_config}], params={params}",
|
||||
)
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
|
||||
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
|
||||
# logprobs comparison when logprobs are requested
|
||||
skip_logprobs_check = (
|
||||
current_platform.is_rocm()
|
||||
and params.get("logprobs")
|
||||
and is_testing_with_spec_decoding
|
||||
)
|
||||
if not skip_logprobs_check:
|
||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||
|
||||
if (
|
||||
base_acceptance_rate is not None
|
||||
and test_acceptance_rate is not None
|
||||
):
|
||||
if "spec_mml=None" in test_config:
|
||||
# Preemption causes more variance in acceptance rates
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and "preemption=True" in test_config
|
||||
):
|
||||
tolerance = 0.10
|
||||
else:
|
||||
tolerance = 0.05
|
||||
assert (
|
||||
test_acceptance_rate > base_acceptance_rate
|
||||
or test_acceptance_rate
|
||||
== pytest.approx(base_acceptance_rate, rel=5e-2)
|
||||
== pytest.approx(base_acceptance_rate, rel=tolerance)
|
||||
)
|
||||
else:
|
||||
# Currently the reported acceptance rate is expected to be
|
||||
@ -213,6 +261,7 @@ def run_test(
|
||||
async_scheduling: bool,
|
||||
spec_config: dict[str, Any] | None,
|
||||
test_prefill_chunking: bool,
|
||||
is_testing_with_spec_decoding: bool = False,
|
||||
):
|
||||
spec_decoding = spec_config is not None
|
||||
cache_arg: dict[str, Any] = (
|
||||
@ -231,6 +280,15 @@ def run_test(
|
||||
print("-" * 80)
|
||||
print(f"---- TESTING {test_str}: {test_config}")
|
||||
print("-" * 80)
|
||||
|
||||
# On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
|
||||
# spec decoding test (TRITON_ATTN) for better precision.
|
||||
# On others: always use float32.
|
||||
if current_platform.is_rocm() and not is_testing_with_spec_decoding:
|
||||
dtype = "float16"
|
||||
else:
|
||||
dtype = "float32"
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=512,
|
||||
@ -240,7 +298,7 @@ def run_test(
|
||||
# enforce_eager=True,
|
||||
async_scheduling=async_scheduling,
|
||||
distributed_executor_backend=executor,
|
||||
dtype="float32", # avoid precision errors
|
||||
dtype=dtype,
|
||||
speculative_config=spec_config,
|
||||
disable_log_stats=False,
|
||||
**cache_arg,
|
||||
@ -300,11 +358,21 @@ def _all_logprobs_match(req_a, req_b) -> bool:
|
||||
|
||||
|
||||
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
|
||||
return len(lps_a) == len(lps_b) and all(
|
||||
a.decoded_token == b.decoded_token
|
||||
and a.rank == b.rank
|
||||
and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
|
||||
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
|
||||
if current_platform.is_rocm():
|
||||
# ROCm has higher numerical variance
|
||||
# due to use of float16.
|
||||
rel_tol, abs_tol = 5e-2, 1e-5
|
||||
else:
|
||||
rel_tol, abs_tol = 1e-3, 1e-6
|
||||
return (
|
||||
len(lps_a) == len(lps_b)
|
||||
and lps_a.keys() == lps_b.keys()
|
||||
and all(
|
||||
a.decoded_token == b.decoded_token
|
||||
and a.rank == b.rank
|
||||
and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
|
||||
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
131
tests/v1/e2e/test_async_spec_decode.py
Normal file
131
tests/v1/e2e/test_async_spec_decode.py
Normal file
@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test that verifies no implicit GPU-CPU synchronization occurs during
|
||||
speculative decoding generation under expected conditions.
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sync_tracker():
|
||||
"""
|
||||
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
|
||||
lazy init syncs. Prints stack traces immediately when syncs occur.
|
||||
"""
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
# Shared counter for cross-process communication (inherited by fork)
|
||||
sync_count = multiprocessing.Value("i", 0)
|
||||
|
||||
# Save original property
|
||||
original_prop = CommonAttentionMetadata.seq_lens_cpu
|
||||
original_fget = original_prop.fget
|
||||
|
||||
# Create tracking wrapper
|
||||
def tracking_seq_lens_cpu(self):
|
||||
if self._seq_lens_cpu is None:
|
||||
# Increment counter
|
||||
with sync_count.get_lock():
|
||||
sync_count.value += 1
|
||||
count = sync_count.value
|
||||
# Print stack trace immediately (shows in subprocess output)
|
||||
print(f"\n{'=' * 60}", file=sys.stderr)
|
||||
print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
|
||||
print(f"{'=' * 60}", file=sys.stderr)
|
||||
traceback.print_stack(file=sys.stderr)
|
||||
print(f"{'=' * 60}\n", file=sys.stderr)
|
||||
sys.stderr.flush()
|
||||
return original_fget(self)
|
||||
|
||||
# Apply patch
|
||||
CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)
|
||||
|
||||
class SyncTracker:
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return sync_count.value
|
||||
|
||||
def assert_no_sync(self, msg: str = ""):
|
||||
count = sync_count.value
|
||||
assert count == 0, (
|
||||
f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
|
||||
f"{count} times. See stack traces above. {msg}"
|
||||
)
|
||||
|
||||
yield SyncTracker()
|
||||
|
||||
# Restore original property
|
||||
CommonAttentionMetadata.seq_lens_cpu = original_prop
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
||||
# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env)
|
||||
SPEC_DECODE_CONFIGS = [
|
||||
pytest.param(
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||
"eagle3",
|
||||
2,
|
||||
id="eagle3-llama",
|
||||
),
|
||||
pytest.param(
|
||||
"eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random",
|
||||
"eagle",
|
||||
2,
|
||||
id="eagle-mla-deepseek",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,spec_model,method,num_spec_tokens",
|
||||
SPEC_DECODE_CONFIGS,
|
||||
)
|
||||
def test_no_sync_with_spec_decode(
|
||||
sync_tracker,
|
||||
model: str,
|
||||
spec_model: str,
|
||||
method: str,
|
||||
num_spec_tokens: int,
|
||||
):
|
||||
"""
|
||||
Test that no implicit GPU-CPU sync occurs during speculative decoding
|
||||
generation.
|
||||
"""
|
||||
# Import vLLM AFTER sync_tracker fixture has applied the patch
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
max_model_len=256,
|
||||
speculative_config={
|
||||
"method": method,
|
||||
"num_speculative_tokens": num_spec_tokens,
|
||||
"model": spec_model,
|
||||
},
|
||||
enforce_eager=True,
|
||||
async_scheduling=True,
|
||||
)
|
||||
|
||||
outputs = llm.generate(
|
||||
["Hello, my name is"],
|
||||
SamplingParams(temperature=0, max_tokens=10),
|
||||
)
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert len(outputs[0].outputs[0].text) > 0
|
||||
|
||||
del llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
sync_tracker.assert_no_sync()
|
||||
@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
|
||||
# Expect the acceptance rate to improve.
|
||||
assert first_accept_rate < last_accept_rate
|
||||
|
||||
# Heuristic: expect at least 82.5% acceptance rate at the end.
|
||||
assert last_accept_rate > 0.825
|
||||
# Heuristic: expect at least 80.0% acceptance rate at the end.
|
||||
assert last_accept_rate > 0.80
|
||||
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -88,8 +88,8 @@ def forward_attention(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc.cpu(),
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens.cpu(),
|
||||
num_computed_tokens_cpu=context_lens.cpu(),
|
||||
_seq_lens_cpu=seq_lens.cpu(),
|
||||
_num_computed_tokens_cpu=context_lens.cpu(),
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
|
||||
@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
|
||||
request.use_structured_output = True
|
||||
request.prompt_token_ids = [1, 2, 3, 4, 5]
|
||||
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
request.num_computed_tokens = 5
|
||||
return request
|
||||
|
||||
def test_should_fill_bitmask_with_enable_in_reasoning(
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.profiler.gpu_profiler import WorkerProfiler
|
||||
from vllm.config import ProfilerConfig
|
||||
from vllm.profiler.wrapper import WorkerProfiler
|
||||
|
||||
|
||||
class ConcreteWorkerProfiler(WorkerProfiler):
|
||||
@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
|
||||
A basic implementation of a worker profiler for testing purposes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, profiler_config: ProfilerConfig):
|
||||
self.start_call_count = 0
|
||||
self.stop_call_count = 0
|
||||
self.should_fail_start = False
|
||||
super().__init__()
|
||||
super().__init__(profiler_config)
|
||||
|
||||
def _start(self) -> None:
|
||||
if self.should_fail_start:
|
||||
@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
|
||||
self.stop_call_count += 1
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_mocks():
|
||||
"""Fixture to reset mocks and env variables before each test."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 0
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 0
|
||||
@pytest.fixture
|
||||
def default_profiler_config():
|
||||
return ProfilerConfig(
|
||||
profiler="torch",
|
||||
torch_profiler_dir="/tmp/mock",
|
||||
delay_iterations=0,
|
||||
max_iterations=0,
|
||||
)
|
||||
|
||||
|
||||
def test_immediate_start_stop():
|
||||
def test_immediate_start_stop(default_profiler_config):
|
||||
"""Test standard start without delay."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
profiler.start()
|
||||
assert profiler._running is True
|
||||
assert profiler._active is True
|
||||
@ -48,10 +50,10 @@ def test_immediate_start_stop():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_delayed_start():
|
||||
def test_delayed_start(default_profiler_config):
|
||||
"""Test that profiler waits for N steps before actually starting."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.delay_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# User requests start
|
||||
profiler.start()
|
||||
@ -71,10 +73,10 @@ def test_delayed_start():
|
||||
assert profiler.start_call_count == 1
|
||||
|
||||
|
||||
def test_max_iterations():
|
||||
def test_max_iterations(default_profiler_config):
|
||||
"""Test that profiler stops automatically after max iterations."""
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.max_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
profiler.start()
|
||||
assert profiler._running is True
|
||||
@ -95,12 +97,11 @@ def test_max_iterations():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_delayed_start_and_max_iters():
|
||||
def test_delayed_start_and_max_iters(default_profiler_config):
|
||||
"""Test combined delayed start and max iterations."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
default_profiler_config.delay_iterations = 2
|
||||
default_profiler_config.max_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
profiler.start()
|
||||
|
||||
# Step 1
|
||||
@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_idempotency():
|
||||
def test_idempotency(default_profiler_config):
|
||||
"""Test that calling start/stop multiple times doesn't break logic."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# Double Start
|
||||
profiler.start()
|
||||
@ -142,10 +143,10 @@ def test_idempotency():
|
||||
assert profiler.stop_call_count == 1 # Should only stop once
|
||||
|
||||
|
||||
def test_step_inactive():
|
||||
def test_step_inactive(default_profiler_config):
|
||||
"""Test that stepping while inactive does nothing."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.delay_iterations = 2
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# Not started yet
|
||||
profiler.step()
|
||||
@ -155,9 +156,9 @@ def test_step_inactive():
|
||||
assert profiler.start_call_count == 0
|
||||
|
||||
|
||||
def test_start_failure():
|
||||
def test_start_failure(default_profiler_config):
|
||||
"""Test behavior when the underlying _start method raises exception."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
profiler.should_fail_start = True
|
||||
|
||||
profiler.start()
|
||||
@ -168,9 +169,9 @@ def test_start_failure():
|
||||
assert profiler.start_call_count == 0 # Logic failed inside start
|
||||
|
||||
|
||||
def test_shutdown():
|
||||
def test_shutdown(default_profiler_config):
|
||||
"""Test that shutdown calls stop only if running."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
# Case 1: Not running
|
||||
profiler.shutdown()
|
||||
@ -182,10 +183,10 @@ def test_shutdown():
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_mixed_delay_and_stop():
|
||||
def test_mixed_delay_and_stop(default_profiler_config):
|
||||
"""Test manual stop during the delay period."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 5
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
default_profiler_config.delay_iterations = 5
|
||||
profiler = ConcreteWorkerProfiler(default_profiler_config)
|
||||
|
||||
profiler.start()
|
||||
profiler.step()
|
||||
|
||||
@ -9,6 +9,8 @@ import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
_FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
def is_aiter_found() -> bool:
|
||||
from importlib.util import find_spec
|
||||
@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
|
||||
# we keep this global outside to not cause torch compile breaks.
|
||||
IS_AITER_FOUND = is_aiter_found()
|
||||
|
||||
# Can't use dtypes.fp8 directly inside an op
|
||||
# because it returns wrong result on gfx942.
|
||||
# This is a workaround to get the correct FP8 dtype.
|
||||
# This might because that the get_gfx() is wrapped as a custom op.
|
||||
if IS_AITER_FOUND:
|
||||
from aiter import dtypes
|
||||
|
||||
AITER_FP8_DTYPE = dtypes.fp8
|
||||
|
||||
|
||||
def if_aiter_supported(func: Callable) -> Callable:
|
||||
"""Decorator that only executes the function if
|
||||
@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
|
||||
return wrapper
|
||||
|
||||
|
||||
def _rocm_aiter_group_fp8_quant_impl(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
|
||||
from aiter import QuantType, dtypes, get_hip_quant
|
||||
|
||||
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
|
||||
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
|
||||
|
||||
|
||||
def _rocm_aiter_group_fp8_quant_fake(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from aiter import dtypes
|
||||
|
||||
M, N = x.shape
|
||||
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
|
||||
out_bs = torch.empty(
|
||||
(
|
||||
M,
|
||||
(N + group_size - 1) // group_size,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
return x_fp8, out_bs
|
||||
|
||||
|
||||
def _rocm_aiter_fused_moe_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@ -467,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||
return torch.empty_like(x), torch.empty_like(residual)
|
||||
|
||||
|
||||
def _rocm_aiter_per_tensor_quant_impl(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from aiter.ops.quant import per_tensor_quant_hip
|
||||
|
||||
return per_tensor_quant_hip(x, scale, quant_dtype)
|
||||
|
||||
|
||||
def _rocm_aiter_per_tensor_quant_fake(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(x, dtype=quant_dtype), torch.empty(
|
||||
1, dtype=torch.float32, device=x.device
|
||||
)
|
||||
|
||||
|
||||
def _rocm_aiter_per_token_quant_impl(
|
||||
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from aiter.ops.quant import dynamic_per_token_scaled_quant
|
||||
|
||||
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||
|
||||
out_shape = x.shape
|
||||
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
|
||||
if scale is None:
|
||||
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
|
||||
dynamic_per_token_scaled_quant(
|
||||
out,
|
||||
x,
|
||||
scale,
|
||||
scale_ub=None,
|
||||
shuffle_scale=False,
|
||||
num_rows=None,
|
||||
num_rows_factor=1,
|
||||
)
|
||||
return out, scale
|
||||
|
||||
|
||||
def _rocm_aiter_per_token_quant_fake(
|
||||
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
out_shape = x.shape
|
||||
return (
|
||||
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
|
||||
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
|
||||
)
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
|
||||
|
||||
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
|
||||
x,
|
||||
weight,
|
||||
variance_epsilon,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
group_size=group_size,
|
||||
dtype_quant=AITER_FP8_DTYPE,
|
||||
res1=residual,
|
||||
)
|
||||
return (x_quant, x_quant_scales, res)
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
M, N = x.shape
|
||||
scale_shape = (M, (N + group_size - 1) // group_size)
|
||||
return (
|
||||
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
||||
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
||||
torch.empty_like(residual, device=residual.device),
|
||||
)
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
|
||||
|
||||
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
|
||||
x,
|
||||
weight,
|
||||
variance_epsilon,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
group_size=group_size,
|
||||
dtype_quant=AITER_FP8_DTYPE,
|
||||
res1=None,
|
||||
)
|
||||
return (x_quant, x_quant_scales)
|
||||
|
||||
|
||||
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
M, N = x.shape
|
||||
scale_shape = (M, (N + group_size - 1) // group_size)
|
||||
return (
|
||||
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
||||
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
||||
)
|
||||
|
||||
|
||||
def _rocm_aiter_group_fp8_quant_impl(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
|
||||
from aiter import QuantType, get_hip_quant
|
||||
|
||||
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
|
||||
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
|
||||
|
||||
|
||||
def _rocm_aiter_group_fp8_quant_fake(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
M, N = x.shape
|
||||
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
|
||||
out_bs = torch.empty(
|
||||
(
|
||||
M,
|
||||
(N + group_size - 1) // group_size,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
return x_fp8, out_bs
|
||||
|
||||
|
||||
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
|
||||
|
||||
return act_mul_and_fp8_group_quant(
|
||||
x,
|
||||
activation="silu",
|
||||
group_size=group_size,
|
||||
dtype_quant=AITER_FP8_DTYPE,
|
||||
)
|
||||
|
||||
|
||||
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
M, N = x.shape
|
||||
assert N % 2 == 0
|
||||
N_half = N // 2
|
||||
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
|
||||
out_bs = torch.empty(
|
||||
(
|
||||
M,
|
||||
(N_half + group_size - 1) // group_size,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
return x_fp8, out_bs
|
||||
|
||||
|
||||
# Global flag to ensure ops are registered only once
|
||||
_OPS_REGISTERED = False
|
||||
|
||||
@ -502,7 +672,7 @@ class rocm_aiter_ops:
|
||||
@if_aiter_supported
|
||||
def is_linear_fp8_enaled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
|
||||
return cls.is_linear_enabled()
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
@ -577,14 +747,6 @@ class rocm_aiter_ops:
|
||||
)
|
||||
|
||||
# register all the custom ops here
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_group_fp8_quant",
|
||||
op_func=_rocm_aiter_group_fp8_quant_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_group_fp8_quant_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_asm_moe_tkw1",
|
||||
op_func=_rocm_aiter_asm_moe_tkw1_impl,
|
||||
@ -644,27 +806,62 @@ class rocm_aiter_ops:
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_a8w8_blockscale",
|
||||
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rms_norm",
|
||||
op_func=_rocm_aiter_rms_norm_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_rms_norm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
|
||||
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
|
||||
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
|
||||
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
|
||||
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
|
||||
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
|
||||
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_group_fp8_quant",
|
||||
op_func=_rocm_aiter_group_fp8_quant_impl,
|
||||
fake_impl=_rocm_aiter_group_fp8_quant_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_per_tensor_quant",
|
||||
op_func=_rocm_aiter_per_tensor_quant_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_per_tensor_quant_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_per_token_quant",
|
||||
op_func=_rocm_aiter_per_token_quant_impl,
|
||||
mutates_args=["scale"],
|
||||
fake_impl=_rocm_aiter_per_token_quant_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
_OPS_REGISTERED = True
|
||||
|
||||
@staticmethod
|
||||
@ -859,6 +1056,22 @@ class rocm_aiter_ops:
|
||||
kv_scale=kv_scale,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def per_tensor_quant(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)
|
||||
|
||||
@staticmethod
|
||||
def per_token_quant(
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
|
||||
|
||||
@staticmethod
|
||||
def triton_fp4_gemm_dynamic_qaunt(
|
||||
x: torch.Tensor,
|
||||
|
||||
@ -1726,7 +1726,7 @@ def scaled_fp8_quant(
|
||||
output, input, scale, scale_ub
|
||||
)
|
||||
else:
|
||||
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
|
||||
scale = torch.empty(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
assert scale.numel() == 1, f"{scale.shape}"
|
||||
|
||||
@ -89,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
|
||||
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
try:
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
|
||||
@ -103,7 +103,7 @@ def create_cross_attention_backend(
|
||||
# needed here to know how many tokens to attend to from the cached
|
||||
# cross-attention KV cache.
|
||||
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
|
||||
new_metadata.seq_lens_cpu = torch.from_numpy(
|
||||
new_metadata._seq_lens_cpu = torch.from_numpy(
|
||||
common_attn_metadata.encoder_seq_lens_cpu
|
||||
)
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ from typing import Any
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
|
||||
raise OSError(
|
||||
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
|
||||
"Please set it to a valid path to use torch profiler."
|
||||
)
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
if args.profile and not engine_args.profiler_config.profiler == "torch":
|
||||
raise ValueError(
|
||||
"The torch profiler is not enabled. Please provide profiler_config."
|
||||
)
|
||||
|
||||
# Lazy import to avoid importing LLM when the bench command is not selected.
|
||||
from vllm import LLM, SamplingParams
|
||||
@ -144,7 +142,7 @@ def main(args: argparse.Namespace):
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
profile_dir = engine_args.profiler_config.torch_profiler_dir
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=profile_dir)
|
||||
return
|
||||
|
||||
@ -1097,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Use Torch Profiler. The endpoint must be launched with "
|
||||
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
|
||||
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-result",
|
||||
|
||||
@ -655,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
"--profile",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use Torch Profiler. The env variable "
|
||||
"VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
|
||||
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
|
||||
)
|
||||
|
||||
# prefix repetition dataset
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
@ -8,15 +10,17 @@ import json
|
||||
import types
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
from vllm.config.utils import Range
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.utils import Range
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
else:
|
||||
|
||||
@ -5,6 +5,7 @@ import functools
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
|
||||
|
||||
if self.pass_config.fuse_norm_quant:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
|
||||
if self.pass_config.fuse_act_quant:
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
@ -53,8 +53,27 @@ class PiecewiseBackend:
|
||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
# TODO: we need to generalize encoder compilation to other models
|
||||
self.is_encoder_compilation = vllm_backend.prefix in [
|
||||
"Qwen2_5_VisionPatchEmbed",
|
||||
"Qwen2_5_VisionPatchMerger",
|
||||
"Qwen2_5_VisionBlock",
|
||||
]
|
||||
|
||||
self.compile_ranges = self.compilation_config.get_compile_ranges()
|
||||
if self.is_encoder_compilation:
|
||||
# For encoder compilation we use the max int32 value
|
||||
# to set the upper bound of the compile ranges
|
||||
max_int32 = 2**31 - 1
|
||||
last_compile_range = self.compile_ranges[-1]
|
||||
assert (
|
||||
last_compile_range.end
|
||||
== vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.compile_ranges[-1] = Range(
|
||||
start=last_compile_range.start, end=max_int32
|
||||
)
|
||||
|
||||
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
|
||||
logger.debug_once(log_string)
|
||||
|
||||
|
||||
242
vllm/compilation/rocm_aiter_fusion.py
Normal file
242
vllm/compilation/rocm_aiter_fusion.py
Normal file
@ -0,0 +1,242 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import empty_bf16
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherSiluAndMul
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
|
||||
AITER_RMS_ADD_GROUP_QUANT_OP = (
|
||||
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
|
||||
)
|
||||
|
||||
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
|
||||
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
|
||||
|
||||
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
|
||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
||||
|
||||
|
||||
class AiterRMSFp8GroupQuantPattern:
|
||||
"""
|
||||
This pattern fuses aiter rms_norm & group fp8 quant custom
|
||||
ops into an aiter rms_norm_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = quant_dtype
|
||||
self.quant_op = quant_op
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
|
||||
|
||||
at2 = self.quant_op(at1, 128)
|
||||
|
||||
return at2[0], at2[1]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
at = AITER_RMS_GROUP_QUANT_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
group_size=128,
|
||||
)
|
||||
|
||||
return at[0], at[1]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AiterFusedAddRMSFp8GroupQuantPattern:
|
||||
"""
|
||||
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
|
||||
into a aiter rms_norm_with_add_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = quant_dtype
|
||||
self.quant_op = quant_op
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
at1 = AITER_RMS_ADD_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
at2 = self.quant_op(at1[0], 128)
|
||||
|
||||
# result, scale, residual
|
||||
return at2[0], at2[1], at1[1]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
at = AITER_RMS_ADD_GROUP_QUANT_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
group_size=128,
|
||||
)
|
||||
|
||||
# result, scale, residual
|
||||
return at[0], at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused add patterns are before simple rms norm,
|
||||
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + dynamic group fp8 quant
|
||||
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
||||
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
AiterFusedAddRMSFp8GroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, quant_op
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
fusion_patterns = [
|
||||
AiterRMSFp8GroupQuantPattern,
|
||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||
]
|
||||
return self.hash_source(self, *fusion_patterns)
|
||||
|
||||
|
||||
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
This pattern fuses aiter silu_and_mul & group fp8 quant custom
|
||||
ops into an aiter silu_and_mul_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_op: OpOverload):
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
self.quant_op = quant_op
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
):
|
||||
at1 = self.silu_and_mul_matcher(input)
|
||||
at2 = self.quant_op(at1, 128)
|
||||
return at2[0], at2[1]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
):
|
||||
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
||||
return at[0], at[1]
|
||||
|
||||
inputs = [
|
||||
self.silu_and_mul_matcher.inputs()[0],
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||
)
|
||||
|
||||
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
||||
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
fusion_patterns = [
|
||||
ActivationQuantPattern,
|
||||
AiterSiluMulFp8GroupQuantPattern,
|
||||
]
|
||||
return VllmInductorPass.hash_source(self, *fusion_patterns)
|
||||
@ -24,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.config.observability import ObservabilityConfig
|
||||
from vllm.config.parallel import EPLBConfig, ParallelConfig
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.profiler import ProfilerConfig
|
||||
from vllm.config.scheduler import SchedulerConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||
@ -89,6 +90,8 @@ __all__ = [
|
||||
"SpeechToTextConfig",
|
||||
# From vllm.config.structured_outputs
|
||||
"StructuredOutputsConfig",
|
||||
# From vllm.config.profiler
|
||||
"ProfilerConfig",
|
||||
# From vllm.config.utils
|
||||
"ConfigType",
|
||||
"SupportsMetricsInfo",
|
||||
|
||||
199
vllm/config/profiler.py
Normal file
199
vllm/config/profiler.py
Normal file
@ -0,0 +1,199 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ProfilerKind = Literal["torch", "cuda"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ProfilerConfig:
|
||||
"""Dataclass which contains profiler config for the engine."""
|
||||
|
||||
profiler: ProfilerKind | None = None
|
||||
"""Which profiler to use. Defaults to None. Options are:
|
||||
|
||||
- 'torch': Use PyTorch profiler.\n
|
||||
- 'cuda': Use CUDA profiler."""
|
||||
|
||||
torch_profiler_dir: str = ""
|
||||
"""Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
|
||||
worker's traces (CPU & GPU) will be saved under this directory. Note that
|
||||
it must be an absolute path."""
|
||||
|
||||
torch_profiler_with_stack: bool = True
|
||||
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
|
||||
|
||||
torch_profiler_with_flops: bool = False
|
||||
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
|
||||
|
||||
torch_profiler_use_gzip: bool = True
|
||||
"""If `True`, saves torch profiler traces in gzip format. Enabled by default"""
|
||||
|
||||
torch_profiler_dump_cuda_time_total: bool = True
|
||||
"""If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""
|
||||
|
||||
torch_profiler_record_shapes: bool = False
|
||||
"""If `True`, records tensor shapes in the torch profiler. Disabled by default."""
|
||||
|
||||
torch_profiler_with_memory: bool = False
|
||||
"""If `True`, enables memory profiling in the torch profiler.
|
||||
Disabled by default."""
|
||||
|
||||
ignore_frontend: bool = False
|
||||
"""If `True`, disables the front-end profiling of AsyncLLM when using the
|
||||
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
|
||||
since the front-end profiling does not track iterations and will capture the
|
||||
entire range.
|
||||
"""
|
||||
|
||||
delay_iterations: int = Field(default=0, ge=0)
|
||||
"""Number of engine iterations to skip before starting profiling.
|
||||
Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
|
||||
"""
|
||||
|
||||
max_iterations: int = Field(default=0, ge=0)
|
||||
"""Maximum number of engine iterations to profile after starting profiling.
|
||||
Defaults to 0, meaning no limit.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
|
||||
"""Get field from env var if set, with deprecation warning."""
|
||||
|
||||
if envs.is_set(env_var_name):
|
||||
value = getattr(envs, env_var_name)
|
||||
logger.warning_once(
|
||||
"Using %s environment variable is deprecated and will be removed in "
|
||||
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
|
||||
"--profiler-config.%s command line argument or "
|
||||
"ProfilerConfig(%s=...) config field instead.",
|
||||
env_var_name,
|
||||
field_name,
|
||||
field_name,
|
||||
)
|
||||
return value
|
||||
return None
|
||||
|
||||
def _set_from_env_if_set(
|
||||
self,
|
||||
field_name: str,
|
||||
env_var_name: str,
|
||||
to_bool: bool = True,
|
||||
to_int: bool = False,
|
||||
) -> None:
|
||||
"""Set field from env var if set, with deprecation warning."""
|
||||
value = self._get_from_env_if_set(field_name, env_var_name)
|
||||
if value is not None:
|
||||
if to_bool:
|
||||
value = value == "1"
|
||||
if to_int:
|
||||
value = int(value)
|
||||
setattr(self, field_name, value)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_profiler_config(self) -> Self:
|
||||
maybe_use_cuda_profiler = self._get_from_env_if_set(
|
||||
"profiler", "VLLM_TORCH_CUDA_PROFILE"
|
||||
)
|
||||
if maybe_use_cuda_profiler is not None:
|
||||
self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None
|
||||
else:
|
||||
self._set_from_env_if_set(
|
||||
"torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False
|
||||
)
|
||||
if self.torch_profiler_dir:
|
||||
self.profiler = "torch"
|
||||
self._set_from_env_if_set(
|
||||
"torch_profiler_record_shapes",
|
||||
"VLLM_TORCH_PROFILER_RECORD_SHAPES",
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"torch_profiler_with_memory",
|
||||
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY",
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"torch_profiler_with_stack",
|
||||
"VLLM_TORCH_PROFILER_WITH_STACK",
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"torch_profiler_with_flops",
|
||||
"VLLM_TORCH_PROFILER_WITH_FLOPS",
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"ignore_frontend",
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM",
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"torch_profiler_use_gzip",
|
||||
"VLLM_TORCH_PROFILER_USE_GZIP",
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"torch_profiler_dump_cuda_time_total",
|
||||
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL",
|
||||
)
|
||||
|
||||
self._set_from_env_if_set(
|
||||
"delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True
|
||||
)
|
||||
|
||||
has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
|
||||
if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
|
||||
logger.warning_once(
|
||||
"Using 'torch' profiler with delay_iterations or max_iterations "
|
||||
"while ignore_frontend is False may result in high overhead."
|
||||
)
|
||||
|
||||
profiler_dir = self.torch_profiler_dir
|
||||
if profiler_dir and self.profiler != "torch":
|
||||
raise ValueError(
|
||||
"torch_profiler_dir is only applicable when profiler is set to 'torch'"
|
||||
)
|
||||
if self.profiler == "torch" and not profiler_dir:
|
||||
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
|
||||
|
||||
if profiler_dir:
|
||||
is_gs_path = (
|
||||
profiler_dir.startswith("gs://")
|
||||
and profiler_dir[5:]
|
||||
and profiler_dir[5] != "/"
|
||||
)
|
||||
if not is_gs_path:
|
||||
self.torch_profiler_dir = os.path.abspath(
|
||||
os.path.expanduser(profiler_dir)
|
||||
)
|
||||
|
||||
return self
|
||||
@ -39,6 +39,7 @@ from .lora import LoRAConfig
|
||||
from .model import ModelConfig
|
||||
from .observability import ObservabilityConfig
|
||||
from .parallel import ParallelConfig
|
||||
from .profiler import ProfilerConfig
|
||||
from .scheduler import SchedulerConfig
|
||||
from .speculative import SpeculativeConfig
|
||||
from .structured_outputs import StructuredOutputsConfig
|
||||
@ -218,6 +219,8 @@ class VllmConfig:
|
||||
You can specify the full compilation config like so:
|
||||
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
||||
"""
|
||||
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
|
||||
"""Profiling configuration."""
|
||||
kv_transfer_config: KVTransferConfig | None = None
|
||||
"""The configurations for distributed KV cache transfer."""
|
||||
kv_events_config: KVEventsConfig | None = None
|
||||
@ -296,6 +299,8 @@ class VllmConfig:
|
||||
vllm_factors.append("None")
|
||||
if self.structured_outputs_config:
|
||||
vllm_factors.append(self.structured_outputs_config.compute_hash())
|
||||
if self.profiler_config:
|
||||
vllm_factors.append(self.profiler_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
vllm_factors.append(self.observability_config.compute_hash())
|
||||
@ -1042,8 +1047,14 @@ class VllmConfig:
|
||||
self.compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
if max_cudagraph_capture_size is None:
|
||||
decode_query_len = 1
|
||||
if (
|
||||
self.speculative_config
|
||||
and self.speculative_config.num_speculative_tokens
|
||||
):
|
||||
decode_query_len += self.speculative_config.num_speculative_tokens
|
||||
max_cudagraph_capture_size = min(
|
||||
self.scheduler_config.max_num_seqs * 2, 512
|
||||
self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
|
||||
)
|
||||
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
|
||||
|
||||
@ -50,6 +50,7 @@ from vllm.config import (
|
||||
ObservabilityConfig,
|
||||
ParallelConfig,
|
||||
PoolerConfig,
|
||||
ProfilerConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
StructuredOutputsConfig,
|
||||
@ -532,6 +533,8 @@ class EngineArgs:
|
||||
worker_cls: str = ParallelConfig.worker_cls
|
||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||
|
||||
profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")
|
||||
|
||||
kv_transfer_config: KVTransferConfig | None = None
|
||||
kv_events_config: KVEventsConfig | None = None
|
||||
|
||||
@ -1164,7 +1167,7 @@ class EngineArgs:
|
||||
vllm_group.add_argument(
|
||||
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
|
||||
)
|
||||
|
||||
vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
|
||||
vllm_group.add_argument(
|
||||
"--optimization-level", **vllm_kwargs["optimization_level"]
|
||||
)
|
||||
@ -1782,6 +1785,7 @@ class EngineArgs:
|
||||
kv_transfer_config=self.kv_transfer_config,
|
||||
kv_events_config=self.kv_events_config,
|
||||
ec_transfer_config=self.ec_transfer_config,
|
||||
profiler_config=self.profiler_config,
|
||||
additional_config=self.additional_config,
|
||||
optimization_level=self.optimization_level,
|
||||
)
|
||||
|
||||
@ -8,3 +8,5 @@ Shared constants for vLLM entrypoints.
|
||||
# These constants help mitigate header abuse attacks
|
||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
|
||||
H11_MAX_HEADER_COUNT_DEFAULT = 256
|
||||
|
||||
MCP_PREFIX = "mcp_"
|
||||
|
||||
@ -19,6 +19,7 @@ from vllm import envs
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.constants import MCP_PREFIX
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
get_encoding,
|
||||
get_streamable_parser_for_assistant,
|
||||
@ -303,7 +304,7 @@ class ParsableContext(ConversationContext):
|
||||
result_str = result.content[0].text
|
||||
|
||||
message = ResponseFunctionToolCallOutputItem(
|
||||
id=f"fco_{random_uuid()}",
|
||||
id=f"mcpo_{random_uuid()}",
|
||||
type="function_call_output",
|
||||
call_id=f"call_{random_uuid()}",
|
||||
output=result_str,
|
||||
@ -385,6 +386,9 @@ class ParsableContext(ConversationContext):
|
||||
if not self.parser.response_messages:
|
||||
return []
|
||||
last_msg = self.parser.response_messages[-1]
|
||||
# change this to a mcp_ function call
|
||||
last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
|
||||
self.parser.response_messages[-1] = last_msg
|
||||
if last_msg.name == "code_interpreter":
|
||||
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
|
||||
elif last_msg.name == "web_search_preview":
|
||||
|
||||
@ -20,6 +20,7 @@ from vllm.beam_search import (
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
PoolerConfig,
|
||||
ProfilerConfig,
|
||||
StructuredOutputsConfig,
|
||||
is_init_field,
|
||||
)
|
||||
@ -211,6 +212,7 @@ class LLM:
|
||||
structured_outputs_config: dict[str, Any]
|
||||
| StructuredOutputsConfig
|
||||
| None = None,
|
||||
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
|
||||
kv_cache_memory_bytes: int | None = None,
|
||||
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
|
||||
logits_processors: list[str | type[LogitsProcessor]] | None = None,
|
||||
@ -282,6 +284,20 @@ class LLM:
|
||||
else:
|
||||
structured_outputs_instance = StructuredOutputsConfig()
|
||||
|
||||
if profiler_config is not None:
|
||||
if isinstance(profiler_config, dict):
|
||||
profiler_config_instance = ProfilerConfig(
|
||||
**{
|
||||
k: v
|
||||
for k, v in profiler_config.items()
|
||||
if is_init_field(ProfilerConfig, k)
|
||||
}
|
||||
)
|
||||
else:
|
||||
profiler_config_instance = profiler_config
|
||||
else:
|
||||
profiler_config_instance = ProfilerConfig()
|
||||
|
||||
# warn about single-process data parallel usage.
|
||||
_dp_size = int(kwargs.get("data_parallel_size", 1))
|
||||
_distributed_executor_backend = kwargs.get("distributed_executor_backend")
|
||||
@ -324,6 +340,7 @@ class LLM:
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
pooler_config=pooler_config,
|
||||
structured_outputs_config=structured_outputs_instance,
|
||||
profiler_config=profiler_config_instance,
|
||||
compilation_config=compilation_config_instance,
|
||||
logits_processors=logits_processors,
|
||||
**kwargs,
|
||||
|
||||
@ -1339,6 +1339,7 @@ class OpenAIServing:
|
||||
)
|
||||
engine_prompt = engine_prompts[0]
|
||||
request_prompt = request_prompts[0]
|
||||
prompt_text, _, _ = self._get_prompt_components(request_prompt)
|
||||
|
||||
# Update the sampling params.
|
||||
sampling_params.max_tokens = self.max_model_len - len(
|
||||
|
||||
@ -99,12 +99,7 @@ class MistralToolParser(ToolParser):
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
if not _is_pre_v11_tokeniser(self.model_tokenizer):
|
||||
self.fn_name_regex = re.compile(
|
||||
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
|
||||
)
|
||||
else:
|
||||
self.fn_name_regex = None
|
||||
self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
|
||||
|
||||
if self.bot_token_id is None:
|
||||
raise RuntimeError(
|
||||
@ -148,23 +143,24 @@ class MistralToolParser(ToolParser):
|
||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||
|
||||
try:
|
||||
# we first try to directly load the json as parsing very nested
|
||||
# jsons is difficult
|
||||
try:
|
||||
if self.fn_name_regex:
|
||||
if not self._is_pre_v11:
|
||||
function_call_arr = []
|
||||
for single_tool_content in model_output.split(self.bot_token):
|
||||
matches = self.fn_name_regex.findall(single_tool_content)
|
||||
if "{" not in single_tool_content:
|
||||
continue
|
||||
|
||||
for match in matches:
|
||||
fn_name = match[0]
|
||||
args = match[1]
|
||||
end_name = single_tool_content.find("{")
|
||||
fn_name, args = (
|
||||
single_tool_content[:end_name],
|
||||
single_tool_content[end_name:],
|
||||
)
|
||||
|
||||
# fn_name is encoded outside serialized json dump
|
||||
# only arguments are serialized
|
||||
function_call_arr.append(
|
||||
{"name": fn_name, "arguments": json.loads(args)}
|
||||
)
|
||||
# fn_name is encoded outside serialized json dump
|
||||
# only arguments are serialized
|
||||
function_call_arr.append(
|
||||
{"name": fn_name, "arguments": json.loads(args)}
|
||||
)
|
||||
else:
|
||||
function_call_arr = json.loads(tool_content)
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@ -22,6 +22,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
from openai.types.responses.tool import Tool
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.constants import MCP_PREFIX
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionMessageParam,
|
||||
ResponseInputOutputItem,
|
||||
@ -44,13 +45,13 @@ def make_response_output_items_from_parsable_context(
|
||||
)
|
||||
if isinstance(output_messages[-1], ResponseFunctionToolCall):
|
||||
mcp_message = McpCall(
|
||||
id=f"mcp_{random_uuid()}",
|
||||
id=f"{MCP_PREFIX}{random_uuid()}",
|
||||
arguments=output_messages[-1].arguments,
|
||||
name=output_messages[-1].name,
|
||||
server_label=output_messages[
|
||||
-1
|
||||
].name, # TODO: store the server label
|
||||
type="mcp_call",
|
||||
type=f"{MCP_PREFIX}call",
|
||||
status="completed",
|
||||
output=message.output,
|
||||
# TODO: support error output
|
||||
@ -98,12 +99,63 @@ def construct_input_messages(
|
||||
if isinstance(request_input, str):
|
||||
messages.append({"role": "user", "content": request_input})
|
||||
else:
|
||||
for item in request_input:
|
||||
messages.append(construct_chat_message_with_tool_call(item))
|
||||
input_messages = construct_chat_messages_with_tool_call(request_input)
|
||||
messages.extend(input_messages)
|
||||
return messages
|
||||
|
||||
|
||||
def construct_chat_message_with_tool_call(
|
||||
def _maybe_combine_reasoning_and_tool_call(
|
||||
item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
|
||||
) -> ChatCompletionMessageParam | None:
|
||||
"""Many models treat MCP calls and reasoning as a single message.
|
||||
This function checks if the last message is a reasoning message and
|
||||
the current message is a tool call"""
|
||||
if not (
|
||||
isinstance(item, ResponseFunctionToolCall) and item.id.startswith(MCP_PREFIX)
|
||||
):
|
||||
return None
|
||||
if len(messages) == 0:
|
||||
return None
|
||||
last_message = messages[-1]
|
||||
if not (
|
||||
last_message.get("role") == "assistant"
|
||||
and last_message.get("reasoning") is not None
|
||||
):
|
||||
return None
|
||||
|
||||
last_message["tool_calls"] = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=item.call_id,
|
||||
function=FunctionCallTool(
|
||||
name=item.name,
|
||||
arguments=item.arguments,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
]
|
||||
return last_message
|
||||
|
||||
|
||||
def construct_chat_messages_with_tool_call(
|
||||
input_messages: list[ResponseInputOutputItem],
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
"""This function wraps _construct_single_message_from_response_item
|
||||
Because some chatMessages come from multiple response items
|
||||
for example a reasoning item and a MCP tool call are two response items
|
||||
but are one chat message
|
||||
"""
|
||||
messages: list[ChatCompletionMessageParam] = []
|
||||
for item in input_messages:
|
||||
maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
if maybe_combined_message is not None:
|
||||
messages[-1] = maybe_combined_message
|
||||
else:
|
||||
messages.append(_construct_single_message_from_response_item(item))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _construct_single_message_from_response_item(
|
||||
item: ResponseInputOutputItem,
|
||||
) -> ChatCompletionMessageParam:
|
||||
if isinstance(item, ResponseFunctionToolCall):
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ProfilerConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@ -35,15 +35,12 @@ async def stop_profile(raw_request: Request):
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
profiler_config = getattr(app.state.args, "profiler_config", None)
|
||||
assert profiler_config is None or isinstance(profiler_config, ProfilerConfig)
|
||||
if profiler_config is not None and profiler_config.profiler is not None:
|
||||
logger.warning_once(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
"Profiler with mode '%s' is enabled in the "
|
||||
"API server. This should ONLY be used for local development!",
|
||||
profiler_config.profiler,
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
logger.warning_once(
|
||||
"CUDA Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
app.include_router(router)
|
||||
|
||||
119
vllm/envs.py
119
vllm/envs.py
@ -75,6 +75,7 @@ if TYPE_CHECKING:
|
||||
VLLM_MM_INPUT_CACHE_GIB: int = 4
|
||||
VLLM_TARGET_DEVICE: str = "cuda"
|
||||
VLLM_MAIN_CUDA_VERSION: str = "12.9"
|
||||
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
|
||||
MAX_JOBS: str | None = None
|
||||
NVCC_THREADS: str | None = None
|
||||
VLLM_USE_PRECOMPILED: bool = False
|
||||
@ -88,20 +89,23 @@ if TYPE_CHECKING:
|
||||
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
|
||||
VLLM_PLUGINS: list[str] | None = None
|
||||
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
|
||||
VLLM_TORCH_CUDA_PROFILE: bool = False
|
||||
# Deprecated env variables for profiling, kept for backward compatibility
|
||||
# See also vllm/config/profiler.py and `--profiler-config` argument
|
||||
VLLM_TORCH_CUDA_PROFILE: str | None = None
|
||||
VLLM_TORCH_PROFILER_DIR: str | None = None
|
||||
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
|
||||
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
|
||||
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
|
||||
VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None
|
||||
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None
|
||||
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None
|
||||
VLLM_TORCH_PROFILER_WITH_STACK: str | None = None
|
||||
VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None
|
||||
VLLM_TORCH_PROFILER_USE_GZIP: str | None = None
|
||||
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None
|
||||
VLLM_PROFILER_DELAY_ITERS: str | None = None
|
||||
VLLM_PROFILER_MAX_ITERS: str | None = None
|
||||
# End of deprecated env variables for profiling
|
||||
VLLM_USE_AOT_COMPILE: bool = False
|
||||
VLLM_USE_BYTECODE_HOOK: bool = False
|
||||
VLLM_FORCE_AOT_LOAD: bool = False
|
||||
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
|
||||
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
|
||||
VLLM_PROFILER_DELAY_ITERS: int = 0
|
||||
VLLM_PROFILER_MAX_ITERS: int = 0
|
||||
VLLM_TORCH_PROFILER_USE_GZIP: bool = True
|
||||
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True
|
||||
VLLM_USE_TRITON_AWQ: bool = False
|
||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||
VLLM_SKIP_P2P_CHECK: bool = False
|
||||
@ -452,6 +456,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Main CUDA version of vLLM. This follows PyTorch but can be overridden.
|
||||
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
|
||||
or "12.9",
|
||||
# Controls PyTorch float32 matmul precision mode within vLLM workers.
|
||||
# Valid options mirror torch.set_float32_matmul_precision
|
||||
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
|
||||
"VLLM_FLOAT32_MATMUL_PRECISION",
|
||||
"highest",
|
||||
["highest", "high", "medium"],
|
||||
case_sensitive=False,
|
||||
),
|
||||
# Maximum number of compilation jobs to run in parallel.
|
||||
# By default this is the number of CPUs
|
||||
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
|
||||
@ -841,71 +853,52 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
|
||||
"VLLM_LORA_RESOLVER_CACHE_DIR", None
|
||||
),
|
||||
# Enables torch CUDA profiling if set.
|
||||
# On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered.
|
||||
"VLLM_TORCH_CUDA_PROFILE": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
|
||||
),
|
||||
# Enables torch CUDA profiling if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
|
||||
# Enables torch profiler if set.
|
||||
# Both AsyncLLM's CPU traces as well as workers'
|
||||
# traces (CPU & GPU) will be saved under this directory.
|
||||
# Note that it must be an absolute path.
|
||||
"VLLM_TORCH_PROFILER_DIR": lambda: (
|
||||
None
|
||||
if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None
|
||||
else (
|
||||
val
|
||||
if val.startswith("gs://") and val[5:] and val[5] != "/"
|
||||
else os.path.abspath(os.path.expanduser(val))
|
||||
)
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"),
|
||||
# Enable torch profiler to record shapes if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES")
|
||||
),
|
||||
# Enable torch profiler to record shapes if set
|
||||
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
|
||||
# not record shapes.
|
||||
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"
|
||||
# Enable torch profiler to profile memory if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY")
|
||||
),
|
||||
# Enable torch profiler to profile memory if set
|
||||
# VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler
|
||||
# will not profile memory.
|
||||
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"
|
||||
# Enable torch profiler to profile stack if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK")
|
||||
),
|
||||
# Enable torch profiler to profile stack if set
|
||||
# VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL
|
||||
# profile stack by default.
|
||||
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"
|
||||
# Enable torch profiler to profile flops if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS")
|
||||
),
|
||||
# Enable torch profiler to profile flops if set
|
||||
# VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will
|
||||
# not profile flops.
|
||||
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
|
||||
),
|
||||
# Disable torch profiling of the AsyncLLMEngine process.
|
||||
# If set to 1, will not profile the engine process.
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
|
||||
# Disable torch profiling of the AsyncLLMEngine process if set to 1.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM")
|
||||
),
|
||||
# Delay number of iterations before starting profiling when using
|
||||
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
|
||||
"VLLM_PROFILER_DELAY_ITERS": lambda: int(
|
||||
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
|
||||
),
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")),
|
||||
# Maximum number of iterations to profile when using the torch/torch CUDA profiler.
|
||||
# If set to 0, will not limit the number of iterations.
|
||||
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
|
||||
"VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"),
|
||||
# Control whether torch profiler gzip-compresses profiling files.
|
||||
# Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default).
|
||||
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0"
|
||||
),
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"),
|
||||
# Control whether torch profiler dumps the self_cuda_time_total table.
|
||||
# Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping
|
||||
# (enabled by default).
|
||||
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0"
|
||||
# Set to 0 to disable dumping the table.
|
||||
# Deprecated, see profiler_config.
|
||||
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: (
|
||||
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL")
|
||||
),
|
||||
# If set, vLLM will use Triton implementations of AWQ.
|
||||
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
|
||||
|
||||
@ -292,7 +292,7 @@ def set_forward_context(
|
||||
if num_tokens_across_dp is None:
|
||||
assert ubatch_slices is None
|
||||
assert num_tokens is not None
|
||||
_, num_tokens_across_dp = coordinate_batch_across_dp(
|
||||
_, num_tokens_across_dp, _ = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
parallel_config=vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
|
||||
@ -935,7 +935,11 @@ def enable_batch_invariant_mode():
|
||||
|
||||
# Batch invariant matmuls are no longer needed after cublas overrides
|
||||
if not is_torch_equal_or_newer("2.10.0.dev"):
|
||||
if current_platform.is_device_capability(100):
|
||||
if (
|
||||
current_platform.is_device_capability(100)
|
||||
or current_platform.is_device_capability(80)
|
||||
or current_platform.is_device_capability(89)
|
||||
):
|
||||
# For PyTorch 2.9, B200 uses GEMV for bs=1
|
||||
# Requires https://github.com/pytorch/pytorch/pull/166735
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
|
||||
@ -4,7 +4,10 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
@ -49,6 +52,7 @@ __all__ = [
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"RoutingMethodType",
|
||||
"SharedFusedMoE",
|
||||
"activation_without_mul",
|
||||
"override_config",
|
||||
|
||||
@ -895,6 +895,48 @@ def get_moe_configs(
|
||||
return None
|
||||
|
||||
|
||||
def _ensure_block_size_k_divisible(
|
||||
size_k: int, block_size_k: int, group_size: int
|
||||
) -> int:
|
||||
"""Ensure block_size_k is a divisor of size_k and divisible by group_size.
|
||||
|
||||
This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which
|
||||
requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0.
|
||||
|
||||
Args:
|
||||
size_k: The size_k dimension that must be divisible by result.
|
||||
block_size_k: Preferred block size (will be adjusted if needed).
|
||||
group_size: The result must be divisible by this.
|
||||
|
||||
Returns:
|
||||
A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size.
|
||||
"""
|
||||
# Fast path: already valid
|
||||
if size_k % block_size_k == 0 and block_size_k % group_size == 0:
|
||||
return block_size_k
|
||||
|
||||
# Find the largest value that:
|
||||
# 1. Divides size_k (size_k % candidate == 0)
|
||||
# 2. Is divisible by group_size (candidate % group_size == 0)
|
||||
# 3. Is <= block_size_k (prefer smaller values close to block_size_k)
|
||||
#
|
||||
# Strategy: Search from min(block_size_k, size_k) down to group_size,
|
||||
# stepping by group_size to ensure divisibility by group_size
|
||||
max_search = min(block_size_k, size_k)
|
||||
start = (max_search // group_size) * group_size
|
||||
for candidate in range(start, group_size - 1, -group_size):
|
||||
if size_k % candidate == 0:
|
||||
return candidate
|
||||
|
||||
# Fallback: if group_size divides size_k, use it
|
||||
# This should always be true with correct group_size configuration
|
||||
if size_k % group_size == 0:
|
||||
return group_size
|
||||
|
||||
# This should not happen with correct group_size, but ensure divisibility
|
||||
return size_k
|
||||
|
||||
|
||||
def get_moe_wna16_block_config(
|
||||
config: dict[str, int],
|
||||
use_moe_wna16_cuda: bool,
|
||||
@ -960,6 +1002,9 @@ def get_moe_wna16_block_config(
|
||||
# at the same time.
|
||||
block_size_n = 1024
|
||||
|
||||
# Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
|
||||
block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
|
||||
|
||||
return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
@ -100,22 +99,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
@ -97,23 +96,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -127,10 +109,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=self.allow_inplace,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
)
|
||||
|
||||
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
|
||||
|
||||
@ -33,10 +33,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
init_aiter_topK_meta_data,
|
||||
)
|
||||
@ -57,11 +53,8 @@ from vllm.utils.torch_utils import (
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fused_moe import eplb_map_to_physical_and_record, fused_experts
|
||||
from .fused_moe import eplb_map_to_physical_and_record
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
FusedMoEPermuteExpertsUnpermute = object # type: ignore
|
||||
FusedMoEPrepareAndFinalize = object # type: ignore
|
||||
|
||||
def _eplb_map_to_physical_and_record(
|
||||
topk_ids: torch.Tensor,
|
||||
@ -483,7 +476,7 @@ class FusedMoE(CustomOp):
|
||||
enable_eplb=self.enable_eplb,
|
||||
)
|
||||
|
||||
self.expert_map: torch.Tensor | None
|
||||
self._expert_map: torch.Tensor | None
|
||||
local_num_experts, expert_map, expert_mask = determine_expert_map(
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
@ -493,7 +486,7 @@ class FusedMoE(CustomOp):
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
self.register_buffer("_expert_map", expert_map)
|
||||
self.register_buffer("expert_mask", expert_mask)
|
||||
self._maybe_init_expert_routing_tables()
|
||||
logger.info_once(
|
||||
@ -506,10 +499,10 @@ class FusedMoE(CustomOp):
|
||||
self.expert_placement_strategy,
|
||||
self.local_num_experts,
|
||||
self.global_num_experts,
|
||||
get_compressed_expert_map(self.expert_map),
|
||||
get_compressed_expert_map(self._expert_map),
|
||||
)
|
||||
else:
|
||||
self.local_num_experts, self.expert_map, self.expert_mask = (
|
||||
self.local_num_experts, self._expert_map, self.expert_mask = (
|
||||
self.global_num_experts,
|
||||
None,
|
||||
None,
|
||||
@ -781,7 +774,7 @@ class FusedMoE(CustomOp):
|
||||
),
|
||||
)
|
||||
|
||||
if self.expert_map is None:
|
||||
if self._expert_map is None:
|
||||
return None
|
||||
|
||||
routing_tables = self.ensure_round_robin_expert_routing_tables(
|
||||
@ -789,7 +782,7 @@ class FusedMoE(CustomOp):
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
local_num_experts=self.local_num_experts,
|
||||
device=self.expert_map.device,
|
||||
device=self._expert_map.device,
|
||||
)
|
||||
|
||||
global_to_physical, physical_to_global, local_global = routing_tables
|
||||
@ -840,8 +833,8 @@ class FusedMoE(CustomOp):
|
||||
|
||||
def update_expert_map(self):
|
||||
# ep_size and ep_rank should already be updated
|
||||
assert self.expert_map is not None
|
||||
with self.expert_map.device:
|
||||
assert self._expert_map is not None
|
||||
with self._expert_map.device:
|
||||
local_num_experts, expert_map, expert_mask = determine_expert_map(
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
@ -851,7 +844,7 @@ class FusedMoE(CustomOp):
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
self.register_buffer("_expert_map", expert_map)
|
||||
self.register_buffer("expert_mask", expert_mask)
|
||||
self._maybe_init_expert_routing_tables()
|
||||
if self.aiter_fmoe_shared_expert_enabled:
|
||||
@ -888,7 +881,7 @@ class FusedMoE(CustomOp):
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: We dont need shared_output.record_stream(current_stream())
|
||||
# NOTE: We don't need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
|
||||
@ -1068,9 +1061,9 @@ class FusedMoE(CustomOp):
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
|
||||
if self.expert_map is None:
|
||||
if self._expert_map is None:
|
||||
return expert_id
|
||||
return self.expert_map[expert_id].item()
|
||||
return self._expert_map[expert_id].item()
|
||||
|
||||
def _init_aiter_shared_experts_topK_buffer(
|
||||
self, vllm_config: VllmConfig, dp_size: int
|
||||
@ -1744,6 +1737,12 @@ class FusedMoE(CustomOp):
|
||||
reduce_output(fused_output)[..., :og_hidden_states],
|
||||
)
|
||||
|
||||
@property
|
||||
def expert_map(self) -> torch.Tensor | None:
|
||||
return (
|
||||
self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1805,24 +1804,6 @@ class FusedMoE(CustomOp):
|
||||
layer=self,
|
||||
x=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map
|
||||
if not self.rocm_aiter_fmoe_enabled
|
||||
else self.expert_mask,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
@ -1968,25 +1949,6 @@ class FusedMoE(CustomOp):
|
||||
if do_naive_dispatch_combine
|
||||
else hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map
|
||||
if not self.rocm_aiter_fmoe_enabled
|
||||
else self.expert_mask,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -269,53 +268,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
|
||||
return self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
@ -333,24 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -364,9 +307,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
elif self.flashinfer_cutlass_moe_enabled:
|
||||
return self.flashinfer_cutlass_moe(
|
||||
@ -375,8 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
result = fused_experts(
|
||||
@ -386,11 +329,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
|
||||
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
|
||||
@ -405,148 +348,101 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
enable_eplb is not False
|
||||
or expert_load_view is not None
|
||||
or logical_to_physical_map is not None
|
||||
or logical_replica_count is not None
|
||||
layer.enable_eplb is not False
|
||||
or layer.expert_load_view is not None
|
||||
or layer.logical_to_physical_map is not None
|
||||
or layer.logical_replica_count is not None
|
||||
):
|
||||
raise NotImplementedError("Expert load balancing is not supported for CPU.")
|
||||
|
||||
return layer.cpu_fused_moe(
|
||||
layer,
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
layer.use_grouped_topk,
|
||||
layer.top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
layer.renormalize,
|
||||
layer.topk_group,
|
||||
layer.num_expert_group,
|
||||
layer.global_num_experts,
|
||||
layer.expert_map,
|
||||
layer.custom_routing_function,
|
||||
layer.scoring_func,
|
||||
layer.routed_scaling_factor,
|
||||
layer.e_score_correction_bias,
|
||||
layer.apply_router_weight_on_input,
|
||||
layer.activation,
|
||||
)
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
enable_eplb is not False
|
||||
or expert_load_view is not None
|
||||
or logical_to_physical_map is not None
|
||||
or logical_replica_count is not None
|
||||
layer.enable_eplb is not False
|
||||
or layer.expert_load_view is not None
|
||||
or layer.logical_to_physical_map is not None
|
||||
or layer.logical_replica_count is not None
|
||||
):
|
||||
raise NotImplementedError("Expert load balancing is not supported for XPU.")
|
||||
return layer.ipex_fusion(
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
layer.use_grouped_topk,
|
||||
layer.top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
layer.renormalize,
|
||||
layer.topk_group,
|
||||
layer.num_expert_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
assert custom_routing_function is None
|
||||
assert apply_router_weight_on_input is False
|
||||
if scoring_func != "softmax":
|
||||
assert not layer.use_grouped_topk
|
||||
assert layer.num_expert_group is None
|
||||
assert layer.topk_group is None
|
||||
assert layer.custom_routing_function is None
|
||||
assert layer.apply_router_weight_on_input is False
|
||||
if layer.scoring_func != "softmax":
|
||||
raise NotImplementedError(
|
||||
"Only softmax scoring function is supported for TPU."
|
||||
)
|
||||
if e_score_correction_bias is not None:
|
||||
if layer.e_score_correction_bias is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert score correction bias is not supported for TPU."
|
||||
)
|
||||
assert activation == "silu", f"{activation} is not supported for TPU."
|
||||
assert routed_scaling_factor == 1.0, (
|
||||
f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU."
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} is not supported for TPU."
|
||||
)
|
||||
assert layer.routed_scaling_factor == 1.0, (
|
||||
f"routed_scaling_factor {layer.routed_scaling_factor} is "
|
||||
"not supported for TPU."
|
||||
)
|
||||
if (
|
||||
enable_eplb is not False
|
||||
or expert_load_view is not None
|
||||
or logical_to_physical_map is not None
|
||||
or logical_replica_count is not None
|
||||
layer.enable_eplb is not False
|
||||
or layer.expert_load_view is not None
|
||||
or layer.logical_to_physical_map is not None
|
||||
or layer.logical_replica_count is not None
|
||||
):
|
||||
raise NotImplementedError("Expert load balancing is not supported for TPU.")
|
||||
return fused_moe_pallas(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk=top_k,
|
||||
topk=layer.top_k,
|
||||
gating_output=router_logits,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
renormalize=layer.renormalize,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
@ -669,25 +668,8 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -708,9 +690,9 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
@ -498,23 +497,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -534,10 +516,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
@ -558,31 +557,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
if (
|
||||
self.allow_flashinfer
|
||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
):
|
||||
if enable_eplb:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
|
||||
)
|
||||
@ -591,12 +573,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
top_k=layer.top_k,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
@ -619,9 +601,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
global_scale1=layer.w13_weight_scale_2,
|
||||
global_scale2=layer.w2_weight_scale_2,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
@ -646,15 +628,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
|
||||
assert expert_map is None, (
|
||||
assert layer.expert_map is None, (
|
||||
"Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"CompressedTensorsW4A4Nvfp4MoEMethod."
|
||||
@ -670,7 +652,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
# TODO(bnell): derive these from arguments
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
@ -1188,23 +1170,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -1215,7 +1180,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||
|
||||
if self.use_marlin:
|
||||
assert activation == "silu", f"{activation} not supported for Marlin MoE."
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -1228,9 +1195,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
@ -1248,9 +1215,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@ -1270,10 +1237,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=None
|
||||
if self.disable_expert_map
|
||||
else layer.expert_map, # ???
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
@ -1290,9 +1259,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
@ -1314,10 +1283,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@ -1437,23 +1406,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -1469,10 +1421,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@ -1814,25 +1766,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert activation == "silu", f"{activation} not supported for Marlin MoE."
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -1853,9 +1790,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
g_idx1=layer.w13_weight_g_idx,
|
||||
g_idx2=layer.w2_weight_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
@ -2057,23 +1994,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -2089,10 +2009,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@ -2372,32 +2292,15 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
|
||||
assert activation in ("silu", "swigluoai", "swiglu"), (
|
||||
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
|
||||
assert layer.activation in ("silu", "swigluoai", "swiglu"), (
|
||||
"Only SiLU/SwiGLUGU/SwiGLUUG are supported."
|
||||
)
|
||||
assert expert_map is None, """expert_map/EP not implemented
|
||||
assert layer.expert_map is None, """expert_map/EP not implemented
|
||||
for CPU dyn-4bit MoE."""
|
||||
|
||||
def _act_kind(s: str) -> int:
|
||||
@ -2414,15 +2317,9 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
top_k=layer.top_k,
|
||||
use_grouped_topk=layer.use_grouped_topk,
|
||||
renormalize=layer.renormalize,
|
||||
)
|
||||
|
||||
return torch.ops._C.dynamic_4bit_int_moe(
|
||||
@ -2435,8 +2332,8 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_in_features,
|
||||
layer.w13_out_features,
|
||||
layer.group_size,
|
||||
apply_router_weight_on_input,
|
||||
int(_act_kind(activation)),
|
||||
layer.apply_router_weight_on_input,
|
||||
int(_act_kind(layer.activation)),
|
||||
)
|
||||
|
||||
|
||||
@ -2707,28 +2604,11 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
):
|
||||
if enable_eplb:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
|
||||
)
|
||||
@ -2749,9 +2629,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
a_strides1=self.a_strides1_c_strides2,
|
||||
a_strides2=self.a_strides2,
|
||||
b_strides1=self.b_strides1,
|
||||
|
||||
@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# dont restrict as emulations
|
||||
# don't restrict as emulations
|
||||
return 80
|
||||
|
||||
def create_weights(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@ -140,23 +139,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -172,10 +154,10 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
@ -99,7 +98,7 @@ from vllm.model_executor.parameter import (
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import (
|
||||
@ -559,46 +558,50 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
assert not self.act_q_static
|
||||
size_k_first = False
|
||||
|
||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
||||
weight, weight_scale_inv = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale_inv
|
||||
)
|
||||
# Delete the weight_scale_inv parameter to avoid confusion
|
||||
# with the weight_scale parameter
|
||||
del layer.weight_scale_inv
|
||||
|
||||
# Update layer with new values
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
weight = qweight.t()
|
||||
|
||||
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
|
||||
# shards in a fused module
|
||||
else:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
weight = qweight.t()
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||
weight,
|
||||
weight_scale,
|
||||
layer.logical_widths,
|
||||
getattr(layer, "input_scale", None),
|
||||
)
|
||||
if self.act_q_static:
|
||||
assert input_scale is not None
|
||||
input_scale = input_scale.max()
|
||||
weight = weight.t()
|
||||
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
|
||||
# shards in a fused module
|
||||
else:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# Update layer with new values.
|
||||
layer.weight = Parameter(weight.data, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
|
||||
layer.input_scale = (
|
||||
Parameter(input_scale, requires_grad=False)
|
||||
if input_scale is not None
|
||||
else None
|
||||
)
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
weight, weight_scale, input_scale = (
|
||||
process_fp8_weight_tensor_strategy(
|
||||
weight,
|
||||
weight_scale,
|
||||
layer.logical_widths,
|
||||
getattr(layer, "input_scale", None),
|
||||
)
|
||||
)
|
||||
if self.act_q_static:
|
||||
assert input_scale is not None
|
||||
input_scale = input_scale.max()
|
||||
weight = weight.t()
|
||||
|
||||
# Update layer with new values.
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||
|
||||
if input_scale is not None:
|
||||
replace_parameter(layer, "input_scale", input_scale)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(
|
||||
@ -625,7 +628,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
@ -654,10 +657,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
|
||||
if self.use_marlin:
|
||||
if self.block_quant:
|
||||
weight_scale = layer.weight_scale_inv
|
||||
else:
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale=weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
@ -671,7 +679,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
@ -941,22 +949,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w2_weight_scale_inv = layer.w2_weight_scale_inv
|
||||
|
||||
# torch.compile() cannot use Parameter subclasses.
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale_inv = Parameter(
|
||||
w13_weight_scale_inv, requires_grad=False
|
||||
)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(
|
||||
w2_weight_scale_inv, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||
# it ahead of time for performance reasons.
|
||||
@ -994,13 +998,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w13_weight_scale",
|
||||
torch.ones(
|
||||
layer.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
for expert in range(layer.local_num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
@ -1009,16 +1014,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
@ -1039,12 +1045,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
|
||||
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||
@ -1058,22 +1060,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_weight_scale, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", w13_input_scale)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w2_input_scale", w2_input_scale)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
@ -1097,12 +1091,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
max_w13_scales, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_weight_scale", max_w13_scales)
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
# NOTE: weights have to be swapped since the activation is
|
||||
@ -1245,41 +1237,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
assert isinstance(layer, FusedMoE)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
assert activation == "silu", (
|
||||
f"Expected 'silu' activation but got {activation}"
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
||||
assert layer.activation == "silu", (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
e_score_correction_bias = (
|
||||
e_score_correction_bias.to(x.dtype)
|
||||
if e_score_correction_bias is not None
|
||||
layer.e_score_correction_bias.to(x.dtype)
|
||||
if layer.e_score_correction_bias is not None
|
||||
else None
|
||||
)
|
||||
routing_method_type = layer.routing_method_type
|
||||
@ -1293,29 +1264,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=routing_method_type,
|
||||
routed_scaling=routed_scaling_factor,
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
assert not renormalize and custom_routing_function is not None
|
||||
assert (
|
||||
not layer.renormalize and layer.custom_routing_function is not None
|
||||
)
|
||||
result = apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
select_result = layer.select_experts(
|
||||
@ -1336,13 +1309,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
elif self.use_marlin:
|
||||
assert activation == "silu", f"{activation} not supported for Marlin MoE."
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
result = fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -1355,20 +1330,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert activation == "silu", (
|
||||
f"Expected 'silu' activation but got {activation}"
|
||||
assert layer.activation == "silu", (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
)
|
||||
if not self.block_quant:
|
||||
assert not renormalize and custom_routing_function is not None
|
||||
assert scoring_func == "sigmoid", (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}"
|
||||
assert (
|
||||
not layer.renormalize and layer.custom_routing_function is not None
|
||||
)
|
||||
assert layer.scoring_func == "sigmoid", (
|
||||
f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
|
||||
)
|
||||
# Delegate to CUTLASS FlashInfer path; function already bound with
|
||||
# use_deepseek_fp8_block_scale for block-quant when applicable
|
||||
@ -1378,10 +1355,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@ -1393,10 +1370,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
allow_cutlass_block_scaled_grouped_gemm=(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -625,26 +625,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if apply_router_weight_on_input:
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused GGUF MoE method."
|
||||
@ -662,7 +645,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type,
|
||||
activation,
|
||||
layer.activation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -790,25 +789,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -829,9 +811,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
|
||||
@ -5,6 +5,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
@ -45,10 +46,13 @@ class QuantFP8(CustomOp):
|
||||
super().__init__()
|
||||
self.static = static
|
||||
self.group_shape = group_shape
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
self.num_token_padding = num_token_padding
|
||||
self.column_major_scales = column_major_scales
|
||||
self.use_ue8m0 = use_ue8m0
|
||||
|
||||
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
|
||||
self.is_group_quant = group_shape.is_per_group()
|
||||
if self.is_group_quant:
|
||||
assert not static, "Group quantization only supports dynamic mode"
|
||||
@ -92,6 +96,33 @@ class QuantFP8(CustomOp):
|
||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||
)
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
use_aiter_quant = (
|
||||
not self.is_group_quant
|
||||
and self.use_aiter
|
||||
and scale_ub is None
|
||||
and x.is_contiguous()
|
||||
)
|
||||
use_aiter_per_tensor_quant = (
|
||||
use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
|
||||
)
|
||||
use_aiter_per_token_quant = (
|
||||
use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
|
||||
)
|
||||
|
||||
if use_aiter_per_tensor_quant:
|
||||
return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
|
||||
if use_aiter_per_token_quant:
|
||||
return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)
|
||||
|
||||
# Fallback to CUDA implementation
|
||||
return self.forward_cuda(x, scale, scale_ub)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@ -440,31 +439,14 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return layer.ipex_fusion(
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
layer.use_grouped_topk,
|
||||
layer.top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
layer.renormalize,
|
||||
layer.topk_group,
|
||||
layer.num_expert_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
)
|
||||
|
||||
@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# skip if there are no weights to process (for example, weight reloading)
|
||||
if not hasattr(layer, "q_scale"):
|
||||
assert not hasattr(layer, "k_scale")
|
||||
assert not hasattr(layer, "v_scale")
|
||||
assert not hasattr(layer, "prob_scale")
|
||||
return
|
||||
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from fnmatch import fnmatch
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
@ -706,43 +705,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
|
||||
)
|
||||
assert activation == "silu", (
|
||||
f"Expected 'silu' activation but got {activation}"
|
||||
assert layer.activation == "silu", (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
)
|
||||
assert not renormalize
|
||||
|
||||
assert not layer.renormalize
|
||||
return apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
# Expert selection
|
||||
@ -752,9 +735,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert activation in ("silu", "relu2_no_mul"), (
|
||||
assert layer.activation in ("silu", "relu2_no_mul"), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {activation}"
|
||||
f"but got {layer.activation}"
|
||||
)
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
@ -762,10 +745,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
@ -779,11 +762,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
@ -1503,23 +1486,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.moe.is_act_and_mul:
|
||||
assert (
|
||||
@ -1534,7 +1500,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
self.allow_flashinfer
|
||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
):
|
||||
if enable_eplb:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
|
||||
)
|
||||
@ -1542,12 +1508,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
top_k=layer.top_k,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
@ -1570,9 +1536,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
global_scale1=layer.w13_weight_scale_2,
|
||||
global_scale2=layer.w2_weight_scale_2,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
)
|
||||
|
||||
@ -1603,10 +1569,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
||||
@ -1621,8 +1587,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
# TODO: derive from arguments
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@ -60,7 +59,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
|
||||
if self.linear_quant_method == "gptq":
|
||||
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
|
||||
elif self.linear_quant_method == "awq":
|
||||
elif self.linear_quant_method in ("awq", "awq_marlin"):
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (
|
||||
-1 if capability_tuple is None else capability_tuple.to_int()
|
||||
@ -107,7 +106,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
if linear_quant_method == "gptq":
|
||||
has_zp = not cls.get_from_keys(config, ["sym"])
|
||||
modules_to_not_convert = []
|
||||
elif linear_quant_method == "awq":
|
||||
elif linear_quant_method in ("awq", "awq_marlin"):
|
||||
has_zp = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None
|
||||
@ -184,7 +183,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
return GPTQConfig.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
elif self.linear_quant_method == "awq":
|
||||
elif self.linear_quant_method in ("awq", "awq_marlin"):
|
||||
if self.use_marlin and check_marlin_supports_layer(
|
||||
layer, self.group_size
|
||||
):
|
||||
@ -362,27 +361,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@ -395,9 +377,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@ -468,7 +450,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
|
||||
# convert gptq and awq weight to a standard format
|
||||
if layer.quant_config.linear_quant_method == "awq":
|
||||
# awq_marlin uses the same weight format as awq
|
||||
if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
|
||||
assert layer.quant_config.weight_bits == 4
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
global_scale1=None,
|
||||
global_scale2=None,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
activation=activation,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
activation=layer.activation,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
use_grouped_topk,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
expert_map,
|
||||
custom_routing_function,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
scoring_func,
|
||||
activation,
|
||||
expert_load_view,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
layer.use_grouped_topk,
|
||||
layer.topk_group,
|
||||
layer.num_expert_group,
|
||||
layer.expert_map,
|
||||
layer.custom_routing_function,
|
||||
layer.e_score_correction_bias,
|
||||
layer.apply_router_weight_on_input,
|
||||
layer.scoring_func,
|
||||
layer.activation,
|
||||
layer.expert_load_view,
|
||||
layer.logical_to_physical_map,
|
||||
layer.logical_replica_count,
|
||||
), "MXFP4 are not supported with this configuration."
|
||||
|
||||
if (
|
||||
@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
None, # output1_scale_scalar
|
||||
None, # output1_scale_gate_scalar
|
||||
None, # output2_scale_scalar
|
||||
global_num_experts,
|
||||
top_k,
|
||||
layer.global_num_experts,
|
||||
layer.top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self.num_experts, # local num experts
|
||||
None,
|
||||
None,
|
||||
1 if renormalize else 0, # routing_method_type, renormalize
|
||||
1 if layer.renormalize else 0, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
)[0]
|
||||
@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
topk=layer.top_k,
|
||||
renormalize=layer.renormalize,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
|
||||
@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "swigluoai", (
|
||||
assert layer.activation == "swigluoai", (
|
||||
"Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
|
||||
)
|
||||
hidden_size_pad = round_up(self.original_hidden_size, 128)
|
||||
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
|
||||
hidden_states = layer.ipex_fusion(
|
||||
x_pad,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
layer.use_grouped_topk,
|
||||
layer.top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
layer.renormalize,
|
||||
layer.topk_group,
|
||||
layer.num_expert_group,
|
||||
activation="swiglu_oai",
|
||||
)
|
||||
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -337,23 +336,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -371,13 +353,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=expert_map,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
elif self.use_marlin:
|
||||
assert activation == "silu", f"{activation} not supported for Marlin MoE."
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -390,9 +374,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@ -404,10 +388,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@ -597,23 +581,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -631,9 +598,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=expert_map,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@ -645,10 +612,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
# Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
@ -359,23 +358,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
hidden_states=x,
|
||||
@ -394,9 +376,9 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_config.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
workspace=workspace,
|
||||
)
|
||||
|
||||
|
||||
@ -27,10 +27,10 @@ from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
fp8_gemm_nt,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
@ -195,6 +195,39 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def _triton_per_token_group_quant_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return per_token_group_quant_fp8(
|
||||
x, group_size, column_major_scales=False, use_ue8m0=False
|
||||
)
|
||||
|
||||
|
||||
def _triton_per_token_group_quant_fp8_fake(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
M, N = x.shape
|
||||
x_fp8 = torch.empty((M, N), dtype=current_platform.fp8_dtype(), device=x.device)
|
||||
out_bs = torch.empty(
|
||||
(
|
||||
M,
|
||||
(N + group_size - 1) // group_size,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
return x_fp8, out_bs
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"triton_per_token_group_quant_fp8",
|
||||
_triton_per_token_group_quant_fp8_impl,
|
||||
fake_impl=_triton_per_token_group_quant_fp8_fake,
|
||||
)
|
||||
|
||||
|
||||
# TODO fix ROCm->Triton custom path:
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
class W8A8BlockFp8LinearOp:
|
||||
@ -214,6 +247,7 @@ class W8A8BlockFp8LinearOp:
|
||||
self.act_quant_group_shape = act_quant_group_shape
|
||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
self.is_blackwell = current_platform.is_device_capability(100)
|
||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||
|
||||
# Get the correct blockscale mul and input quant operations.
|
||||
@ -269,7 +303,7 @@ class W8A8BlockFp8LinearOp:
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
|
||||
if self.use_deep_gemm_e8m0 and self.is_blackwell:
|
||||
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
input_2d,
|
||||
group_size=self.act_quant_group_shape.col,
|
||||
@ -340,17 +374,15 @@ class W8A8BlockFp8LinearOp:
|
||||
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
# MI350 case uses triton kernel
|
||||
elif use_triton:
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
self.act_quant_group_shape.col,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
# MI300 uses tuned AITER ASM/C++ kernel
|
||||
else:
|
||||
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)
|
||||
q_input, input_scale = rocm_aiter_ops.group_fp8_quant(
|
||||
input_2d, self.act_quant_group_shape.col
|
||||
)
|
||||
|
||||
return gemm_a8w8_blockscale_op(
|
||||
q_input,
|
||||
@ -1404,12 +1436,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
|
||||
if should_use_deepgemm:
|
||||
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||
wq=layer.weight.data,
|
||||
ws=layer.weight_scale.data,
|
||||
ws=layer.weight_scale_inv.data,
|
||||
quant_block_shape=tuple(layer.weight_block_size),
|
||||
use_e8m0=is_deep_gemm_e8m0_used(),
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)
|
||||
replace_parameter(layer, "weight", dg_weight)
|
||||
replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
|
||||
|
||||
|
||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||
|
||||
@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_quant_input,
|
||||
should_use_atomic_add_reduce,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
|
||||
size_n=part_size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
||||
replace_parameter(layer, "weight", marlin_qweight)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
|
||||
scales = layer.weight_scale.to(layer.orig_dtype)
|
||||
elif "weight_scale_inv" in dir(layer):
|
||||
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
||||
del layer.weight_scale_inv
|
||||
|
||||
group_size = -1 if weight_block_size is None else weight_block_size[1]
|
||||
|
||||
@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
|
||||
)
|
||||
if input_dtype != torch.float8_e4m3fn:
|
||||
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
|
||||
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
||||
if hasattr(layer, "weight_scale"):
|
||||
replace_parameter(layer, "weight_scale", marlin_scales)
|
||||
elif hasattr(layer, "weight_scale_inv"):
|
||||
replace_parameter(layer, "weight_scale_inv", marlin_scales)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
assert layer.bias.shape == (part_size_n,)
|
||||
bias = marlin_permute_bias(layer.bias)
|
||||
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
replace_parameter(layer, "bias", bias)
|
||||
|
||||
|
||||
def prepare_moe_fp8_layer_for_marlin(
|
||||
|
||||
@ -95,8 +95,11 @@ def requantize_with_max_scale(
|
||||
# from disk in this case. Skip requantization in this case (since)
|
||||
# we already are quantized with the single scale.
|
||||
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||
#
|
||||
# Extra note: upon weight reloading weight_scale.ndim == 0
|
||||
unfused_module_in_checkpoint = (
|
||||
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
||||
weight_scale.ndim != 0
|
||||
and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
||||
)
|
||||
|
||||
# If unfused checkpoint, need requanize with the single scale.
|
||||
|
||||
@ -367,6 +367,8 @@ class Qwen2MoeModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens",
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@ -512,6 +514,12 @@ class Qwen2MoeModel(nn.Module):
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
# GGUF: make sure that shared_expert_gate is a 2D tensor.
|
||||
if (
|
||||
"mlp.shared_expert_gate" in name
|
||||
and len(loaded_weight.shape) == 1
|
||||
):
|
||||
loaded_weight = loaded_weight[None, :]
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
|
||||
@ -403,6 +403,7 @@ class Qwen3MoeModel(nn.Module):
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
@ -505,6 +506,19 @@ class Qwen3MoeModel(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
assert loaded_weight.numel() == 1, (
|
||||
f"KV scale numel {loaded_weight.numel()} != 1"
|
||||
)
|
||||
loaded_weight = loaded_weight.squeeze()
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
|
||||
@ -50,6 +50,31 @@ def set_weight_attrs(
|
||||
setattr(weight, key, value)
|
||||
|
||||
|
||||
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
|
||||
"""
|
||||
Replace a parameter of a layer while maintaining the ability to reload the weight.
|
||||
Called within implementations of the `process_weights_after_loading` method.
|
||||
|
||||
This function should not be called on weights which are tied/shared
|
||||
|
||||
Args:
|
||||
layer: Layer containing parameter to replace
|
||||
param_name: Name of parameter to replace
|
||||
new_data: New data of the new parameter
|
||||
"""
|
||||
# should not be used on a tied/shared param
|
||||
if isinstance(new_data, torch.nn.Parameter):
|
||||
new_data = new_data.data
|
||||
new_param = torch.nn.Parameter(new_data, requires_grad=False)
|
||||
|
||||
old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
|
||||
if old_param is not None and hasattr(old_param, "weight_loader"):
|
||||
weight_loader = old_param.weight_loader
|
||||
set_weight_attrs(new_param, {"weight_loader": weight_loader})
|
||||
|
||||
setattr(layer, param_name, new_param)
|
||||
|
||||
|
||||
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
||||
parent_map = getattr(model, "packed_modules_mapping", None)
|
||||
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
|
||||
|
||||
@ -381,6 +381,8 @@ class RocmPlatform(Platform):
|
||||
compilation_config = vllm_config.compilation_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
is_eager_execution = compilation_config == CUDAGraphMode.NONE
|
||||
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
|
||||
if compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
# decode context parallel does not support full cudagraphs
|
||||
@ -400,8 +402,6 @@ class RocmPlatform(Platform):
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
|
||||
@ -415,6 +415,9 @@ class RocmPlatform(Platform):
|
||||
):
|
||||
compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
|
||||
compilation_config.custom_ops.append("+quant_fp8")
|
||||
|
||||
@classmethod
|
||||
def verify_model_arch(cls, model_arch: str) -> None:
|
||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user